ESMFold2-Fast / modeling_esmfold2_common.py
lhallee's picture
Upload folder using huggingface_hub
407875b verified
Raw
History Blame Contribute Delete
106 kB
# coding=utf-8
# Copyright 2026 Biohub. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Shared building blocks for ESMFold2 HuggingFace model variants."""
from __future__ import annotations
import random
import importlib
from contextlib import contextmanager
from functools import partial
from typing import cast
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.checkpoint import checkpoint
# Optional dependency: flash-attn. When the import fails, every kernel handle
# below is bound to None and FLASH_ATTN_AVAILABLE is False, so callers must
# check the flag before using any of them.
try:
    flash_attn_module = importlib.import_module("flash_attn")
    flash_bert_padding = importlib.import_module("flash_attn.bert_padding")
    flash_attn_func = flash_attn_module.flash_attn_func
    flash_attn_varlen_func = flash_attn_module.flash_attn_varlen_func
    # Helpers used to pack/unpack padded batches around the varlen kernel.
    index_first_axis = flash_bert_padding.index_first_axis
    pad_input = flash_bert_padding.pad_input
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    flash_attn_func = None # type: ignore[assignment]
    flash_attn_varlen_func = None # type: ignore[assignment]
    index_first_axis = None # type: ignore[assignment]
    pad_input = None # type: ignore[assignment]
    FLASH_ATTN_AVAILABLE = False
# Optional dependency: cuequivariance_torch. CUE_AVAILABLE records whether the
# two kernel handles below are usable; they are None when the import fails.
try:
    cue_module = importlib.import_module("cuequivariance_torch")
    cue_triangle = importlib.import_module("cuequivariance_torch.primitives.triangle")
    _cue_attn_pair_bias = cue_module.attention_pair_bias
    _cue_tri_mul = cue_triangle.triangle_multiplicative_update
    CUE_AVAILABLE = True
except ImportError:
    _cue_attn_pair_bias = None # type: ignore[assignment]
    _cue_tri_mul = None # type: ignore[assignment]
    CUE_AVAILABLE = False
# The Biohub release includes optional Triton kernels. FastPLMs keeps the
# reference path enabled by default so Hugging Face remote-code loading can stay
# flat and self-contained.
# All fused entry points are therefore bound to None here and
# TRITON_KERNELS_AVAILABLE is False, which makes ``_fused_active`` return False.
_fused_pair_bias = None
_fused_trimul_with_residual = None
_FusedLNLinearSwiGLU = None
_FusedDropoutResidual = None
TRITON_KERNELS_AVAILABLE = False
from .configuration_esmfold2 import ESMFold2Config
# Names accepted as a module's ``_kernel_backend``. ``None`` selects the
# reference PyTorch implementation.
BACKEND_FUSED = "fused"
BACKEND_CUEQ = "cuequivariance"
_VALID_BACKENDS = (None, BACKEND_FUSED, BACKEND_CUEQ)
def _fused_active(module: nn.Module, tensor: Tensor) -> bool:
    """Report whether the vendored fused Triton inference kernels may be used.

    All of the following must hold: the kernels are flagged as available,
    ``module`` has selected the fused backend, autograd is disabled, and
    ``tensor`` lives on a CUDA device.
    """
    if not TRITON_KERNELS_AVAILABLE:
        return False
    if getattr(module, "_kernel_backend", None) != BACKEND_FUSED:
        return False
    if torch.is_grad_enabled():
        # The fused kernels are inference-only.
        return False
    return tensor.is_cuda
def _cueq_active(module: nn.Module) -> bool:
    """Report whether ``module`` should route through the cuequivariance kernels.

    True only when the package was importable and ``module`` has selected the
    cuequivariance backend.
    """
    if not CUE_AVAILABLE:
        return False
    selected_backend = getattr(module, "_kernel_backend", None)
    return selected_backend == BACKEND_CUEQ
class DropoutResidual(nn.Module):
    """Compute ``residual + dropout(delta)`` with a mask shared along one axis.

    The call signature is the same on both paths. With
    ``use_fused_kernels=True`` and ``batch_dim=1`` (and the fused kernel
    present) the work is delegated to ``FusedDropoutResidual``; in every other
    case the unfused PyTorch path below is used.

    Args:
        r: dropout probability.
        batch_dim: axis (1 or 2) along which a single dropout mask is shared.
        use_fused_kernels: request the fused implementation when possible.
    """

    def __init__(
        self, r: float, batch_dim: int, use_fused_kernels: bool = False
    ) -> None:
        super().__init__()
        assert batch_dim in (1, 2), f"batch_dim must be 1 or 2, got {batch_dim}"
        fused_requested = (
            use_fused_kernels and batch_dim == 1 and _FusedDropoutResidual is not None
        )
        self._use_fused_kernels = fused_requested
        self._batch_dim = batch_dim
        self._r = r
        if not fused_requested:
            self._impl: nn.Module = nn.Dropout(r)
        else:
            assert _FusedDropoutResidual is not None
            self._impl = _FusedDropoutResidual(r)

    def forward(self, residual: Tensor, delta: Tensor) -> Tensor:
        """Return ``residual`` plus the (possibly dropped-out) ``delta``."""
        if self._use_fused_kernels:
            return self._impl(residual, delta)

        # Dropout is a no-op in eval mode or when the rate is zero.
        dropout_inactive = self._r == 0.0 or not self.training
        if dropout_inactive:
            return residual + delta

        # Sample the mask with size 1 along ``batch_dim`` so that it is
        # broadcast (i.e. shared) along that axis when applied to ``delta``.
        mask_shape = [*delta.shape]
        mask_shape[self._batch_dim] = 1
        shared_mask = self._impl(delta.new_ones(mask_shape))
        return residual + delta * shared_mask
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Atom-name characters are encoded per position: MAX_CHARS positions, each
# CHAR_VOCAB_SIZE wide (the encoder flattens them to MAX_CHARS * CHAR_VOCAB_SIZE).
CHAR_VOCAB_SIZE: int = 64
MAX_CHARS: int = 4
XYZ_DIMS: int = 3
MAX_ATOMIC_NUMBER: int = 128
# Input feature dim = 3 + 1 + 1 + 128 + 64*4 = 389
ATOM_FEATURE_DIM: int = (
    XYZ_DIMS + 1 + 1 + MAX_ATOMIC_NUMBER + CHAR_VOCAB_SIZE * MAX_CHARS
)
NUM_RES_TYPES: int = 33
_EPS = 1e-5
# Default for the triangle / OPM / pair-transition L² ops. Caps peak memory
# so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438;
# chunk=64 leaves headroom for the largest foldbench targets). Override via
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
# short L but OOM-prone past ~600).
_DEFAULT_CHUNK_SIZE = 64
# ===========================================================================
# MSA inference-time diversity augmentations
# ===========================================================================
def maybe_subsample_msa(
    msa: Tensor,
    msa_attention_mask: Tensor | None,
    has_deletion: Tensor | None,
    deletion_value: Tensor | None,
    *,
    max_depth: int | None,
    enabled: bool,
) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
    """Randomly subsample MSA rows down to ``max_depth``, always keeping row 0.

    The inputs are returned untouched when subsampling is disabled,
    ``max_depth`` is None, or the MSA already has at most ``max_depth`` rows
    (or a single row). Otherwise row 0 plus ``max_depth - 1`` other rows, drawn
    without replacement, are kept in their original order. The same row
    selection is applied to every optional companion tensor that is not None.

    Args:
        msa: tensor whose dim 1 indexes MSA rows.
        msa_attention_mask: optional tensor indexed like ``msa`` along dim 1.
        has_deletion: optional tensor indexed like ``msa`` along dim 1.
        deletion_value: optional tensor indexed like ``msa`` along dim 1.
        max_depth: number of rows to keep, or None to disable.
        enabled: master switch for the augmentation.

    Returns:
        ``(msa, msa_attention_mask, has_deletion, deletion_value)`` after row
        selection.

    Raises:
        ValueError: if subsampling is required but ``max_depth`` is below 1,
            which would make it impossible to keep row 0.
    """
    if not enabled or max_depth is None:
        return msa, msa_attention_mask, has_deletion, deletion_value
    depth = msa.size(1)
    if depth <= 1 or depth <= max_depth:
        return msa, msa_attention_mask, has_deletion, deletion_value
    if max_depth < 1:
        # Previously this either failed with an opaque shape-mismatch error
        # or silently dropped every row including row 0.
        raise ValueError(
            f"max_depth must be >= 1 so that row 0 can be kept, got {max_depth}"
        )

    # Draw from rows 1..depth-1 only; row 0 is prepended unconditionally.
    sampled = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
    indices = torch.cat([sampled.new_zeros(1), sampled]).sort().values

    msa = msa[:, indices]
    if msa_attention_mask is not None:
        msa_attention_mask = msa_attention_mask[:, indices]
    if has_deletion is not None:
        has_deletion = has_deletion[:, indices]
    if deletion_value is not None:
        deletion_value = deletion_value[:, indices]
    return msa, msa_attention_mask, has_deletion, deletion_value
def maybe_apply_msa_column_masking(
    msa_attention_mask: Tensor | None,
    rate: float,
) -> Tensor | None:
    """Randomly mask whole MSA columns in every row except row 0.

    One keep/drop decision is drawn per (batch, column); a column is dropped
    when its uniform sample falls below ``rate``. Row 0 is never masked.
    The input is returned as-is when it is None, when ``rate`` is not
    positive, or when there is at most one row.

    Returns:
        A boolean mask of the same shape when masking was applied, otherwise
        the original argument.
    """
    if msa_attention_mask is None:
        return None
    if rate <= 0.0 or msa_attention_mask.size(1) <= 1:
        return msa_attention_mask

    n_batch, n_rows, n_cols = msa_attention_mask.shape
    device = msa_attention_mask.device
    column_kept = torch.rand(n_batch, n_cols, device=device) >= rate
    is_first_row = torch.arange(n_rows, device=device) == 0
    # Broadcast the per-column decision over rows, exempting row 0.
    keep = column_kept[:, None, :] | is_first_row[None, :, None]
    return msa_attention_mask.bool() & keep
# ===========================================================================
# Atom-token utilities
# ===========================================================================
def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
    """Copy each atom's owning-token feature vector onto the atom.

    Args:
        token_features: [B, L, d]
        atom_to_token_idx: [B, A] int64, token index of every atom

    Returns:
        [B, A, d]
    """
    feature_dim = token_features.shape[-1]
    # gather needs an index with the same trailing width as the source.
    gather_index = atom_to_token_idx[..., None].expand(-1, -1, feature_dim)
    return token_features.gather(dim=1, index=gather_index)
def scatter_atom_to_token(
    atom_features: Tensor,
    atom_to_token_idx: Tensor,
    n_tokens: int,
    atom_mask: Tensor | None = None,
) -> Tensor:
    """Mean-pool per-atom features onto the tokens that own them.

    Tokens that receive no atom keep an all-zero feature vector. When
    ``atom_mask`` is given, masked-out atoms are routed to an extra scratch
    slot that is dropped before returning.

    Args:
        atom_features: [B, A, d]
        atom_to_token_idx: [B, A] int64
        n_tokens: L
        atom_mask: [B, A] bool

    Returns:
        [B, L, d]
    """
    n_batch, n_atoms, width = atom_features.shape
    if atom_mask is None:
        target = atom_to_token_idx
        n_slots = n_tokens
    else:
        # Index ``n_tokens`` is the scratch slot for masked atoms.
        target = torch.where(atom_mask, atom_to_token_idx, n_tokens)
        n_slots = n_tokens + 1

    pooled = torch.zeros(
        n_batch, n_slots, width, device=atom_features.device, dtype=atom_features.dtype
    )
    pooled.scatter_reduce_(
        1,
        target.unsqueeze(-1).expand(n_batch, n_atoms, width),
        atom_features,
        reduce="mean",
        include_self=False,
    )
    return pooled[:, :n_tokens, :]
def gather_rep_atom_coords(coords: Tensor, rep_atom_idx: Tensor) -> Tensor:
    """Select one representative atom's coordinates per token.

    Args:
        coords: [B, A, 3]
        rep_atom_idx: [B, L] int64, atom index chosen for each token

    Returns:
        [B, L, 3]
    """
    n_coord_dims = coords.shape[-1]
    gather_index = rep_atom_idx[..., None].expand(-1, -1, n_coord_dims)
    return coords.gather(dim=1, index=gather_index)
def _compute_intra_token_idx(atom_to_token: Tensor) -> Tensor:
    """Number the atoms inside each token 0, 1, 2, ... without a Python loop.

    Relies on atoms of the same token being contiguous: the counter restarts
    whenever the token id differs from the previous atom's.

    Args:
        atom_to_token: [B, A] flat index mapping each atom to its token.

    Returns:
        [B, A] tensor with values in [0, max_atoms_per_token - 1].
    """
    # 1-based position of every atom along the atom axis.
    position = torch.ones_like(atom_to_token).cumsum(dim=-1)
    continues_token = F.pad(
        atom_to_token[:, 1:] == atom_to_token[:, :-1], (1, 0), value=False
    )
    # Keep the position only where a new token starts, then carry the most
    # recent start forward with a running maximum.
    start_position = position.masked_fill(continues_token, 0)
    start_position = start_position.cummax(dim=-1).values
    return position - start_position
def _categorical_mean(logits: Tensor, start: float, end: float) -> Tensor:
    """Mean of a categorical distribution whose bins evenly tile ``[start, end]``.

    Each bin is represented by its midpoint. Equivalent to
    ``CategoricalMixture(logits, bins=logits.shape[-1], start, end).mean()``.

    Args:
        logits: [..., n_bins]
        start: left boundary
        end: right boundary

    Returns:
        [...] expected value (float32)
    """
    bin_count = logits.shape[-1]
    boundaries = torch.linspace(
        start, end, bin_count + 1, device=logits.device, dtype=torch.float32
    )
    lower, upper = boundaries[:-1], boundaries[1:]
    midpoints = (lower + upper) / 2  # [n_bins]
    probabilities = torch.softmax(logits.float(), dim=-1)
    return torch.matmul(probabilities, midpoints.unsqueeze(1)).squeeze(-1)
# ===========================================================================
# TransitionLayer (used in DiffusionConditioning)
# ===========================================================================
class TransitionLayer(nn.Module):
    """Bias-free SwiGLU transition: LayerNorm, two parallel projections, gated product, output projection.

    Args:
        d_model: input and output width.
        n: expansion factor; the hidden width is ``n * d_model``.
        eps: LayerNorm epsilon.
    """

    def __init__(self, d_model: int, n: int, eps: float = 1e-5) -> None:
        super().__init__()
        expanded = n * d_model
        self.norm = nn.LayerNorm(d_model, eps=eps)
        self.a_proj = nn.Linear(d_model, expanded, bias=False)
        self.b_proj = nn.Linear(d_model, expanded, bias=False)
        self.out_proj = nn.Linear(expanded, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the transition to ``x`` (last dim ``d_model``)."""
        normed = self.norm(x)
        gate = F.silu(self.a_proj(normed))
        value = self.b_proj(normed)
        return self.out_proj(gate * value)
# ===========================================================================
# AdaptiveLayerNorm (used in DiffusionTransformer)
# ===========================================================================
class AdaptiveLayerNorm(nn.Module):
    """Layer norm whose gain and shift are predicted from a conditioning signal.

    ``a`` is normalised without affine parameters; ``s`` is normalised with a
    learned scale only. The output is
    ``sigmoid(s_gate(s_norm)) * a_norm + s_shift(s_norm)``.

    Args:
        d_model: width of the activations ``a``.
        d_cond: width of the conditioning ``s``.
        eps: epsilon used by both normalisations.
    """

    def __init__(self, d_model: int, d_cond: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_cond = d_cond
        self.eps = eps
        self.s_scale = nn.Parameter(torch.ones(d_cond))
        self.s_gate = nn.Linear(d_cond, d_model, bias=True)
        self.s_shift = nn.Linear(d_cond, d_model, bias=False)

    def forward(self, a: Tensor, s: Tensor) -> Tensor:
        """Normalise ``a`` and modulate it with gain/shift derived from ``s``."""
        a_hat = F.layer_norm(a, (self.d_model,), weight=None, bias=None, eps=self.eps)
        s_hat = F.layer_norm(
            s, (self.d_cond,), weight=self.s_scale, bias=None, eps=self.eps
        )
        gain = torch.sigmoid(self.s_gate(s_hat))
        shift = self.s_shift(s_hat)
        return gain * a_hat + shift
# ===========================================================================
# FourierEmbedding
# ===========================================================================
class FourierEmbedding(nn.Module):
    """Random Fourier features of a scalar: ``cos(2*pi*(t*w + b))``.

    ``w`` and ``b`` are sampled once from a standard normal at construction
    and stored as buffers (not parameters).

    Args:
        c: number of output features.
    """

    w: Tensor
    b: Tensor

    def __init__(self, c: int) -> None:
        super().__init__()
        self.c = c
        self.register_buffer("w", torch.randn(c))
        self.register_buffer("b", torch.randn(c))

    def forward(self, t_hat: Tensor) -> Tensor:
        """Embed ``t_hat`` (flattened to 1-D) into a ``[numel, c]`` tensor."""
        times = torch.as_tensor(t_hat, device=self.w.device, dtype=self.w.dtype)
        times = times.reshape(-1)
        phase = times[:, None] * self.w[None, :] + self.b[None, :]
        return torch.cos(2.0 * torch.pi * phase)
# ===========================================================================
# SwiGLU / SwiGLUMLP
# ===========================================================================
def _compute_swiglu_hidden_size(d_model: int, expansion_ratio: int) -> int:
    """Return the SwiGLU hidden width: the model width scaled by the expansion ratio."""
    hidden_size = expansion_ratio * d_model
    return hidden_size
class SwiGLU(nn.Module):
    """SwiGLU feed-forward with a single packed input projection.

    ``w12`` produces both halves of the gated product in one matmul; ``w3``
    maps the gated hidden state to the output width.

    Args:
        in_features: input width.
        hidden_features: width of each half produced by ``w12``.
        out_features: output width; falls back to ``in_features`` when falsy.
        bias: whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int | None = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self.hidden_features = hidden_features

    def forward(self, x: Tensor) -> Tensor:
        """Return ``w3(silu(x1) * x2)`` where ``x1, x2`` are the halves of ``w12(x)``."""
        packed = self.w12(x)
        gate_half, value_half = torch.split(packed, self.hidden_features, dim=-1)
        return self.w3(F.silu(gate_half) * value_half)
class SwiGLUMLP(SwiGLU):
    """``SwiGLU`` specialised to a square MLP (output width = input width), bias-free by default.

    Args:
        d_model: input and output width.
        expansion_ratio: multiplier passed to ``_compute_swiglu_hidden_size``.
        bias: whether the linear layers carry a bias.
    """

    def __init__(
        self, d_model: int, expansion_ratio: int = 4, bias: bool = False
    ) -> None:
        hidden_width = _compute_swiglu_hidden_size(d_model, expansion_ratio)
        super().__init__(d_model, hidden_width, d_model, bias)
# ===========================================================================
# SWA Atom Attention components
# ===========================================================================
def _rotate_half(x: Tensor) -> Tensor:
    """Split the last dim into halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_emb_3d(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the leading channels of ``x`` with per-position cos/sin tables.

    Only the first ``2 * cos.shape[-1]`` channels are rotated; any remaining
    channels are passed through unchanged. The tables are shared across heads.

    Args:
        x: [B, L, H, D]
        cos: [B, L, D/2]
        sin: [B, L, D/2]
    """
    rotary_width = 2 * cos.shape[-1]
    # Insert a head axis and tile each table to cover both rotated halves.
    cos_table = cos[:, :, None, :].tile(1, 1, 1, 2)
    sin_table = sin[:, :, None, :].tile(1, 1, 1, 2)
    rotary_part = x[..., :rotary_width]
    passthrough = x[..., rotary_width:]
    rotated = rotary_part * cos_table + _rotate_half(rotary_part) * sin_table
    return torch.cat([rotated, passthrough], dim=-1)
@torch.compiler.disable
def build_3d_rope(
ref_pos: Tensor,
ref_space_uid: Tensor,
head_dim: int,
n_spatial_per_axis: int = 4,
n_uid_pairs: int = 2,
spatial_base_freq: float = 10000.0,
uid_base_freq: float = 10.0,
) -> tuple[Tensor, Tensor]:
"""Build cos/sin for 3D RoPE + UID RoPE."""
device = ref_pos.device
B, N = ref_pos.shape[:2]
half_dim = head_dim // 2
n_spatial_total = 3 * n_spatial_per_axis
spatial_inv_freq = 1.0 / (
spatial_base_freq
** (
torch.arange(0, n_spatial_per_axis, dtype=torch.float32, device=device)
/ n_spatial_per_axis
)
)
uid_inv_freq = 1.0 / (
uid_base_freq
** (
torch.arange(0, n_uid_pairs, dtype=torch.float32, device=device)
/ n_uid_pairs
)
)
pos_f32 = ref_pos.float()
spatial_freqs = torch.einsum("bna,k->bnak", pos_f32, spatial_inv_freq)
spatial_freqs = spatial_freqs.reshape(B, N, n_spatial_total)
uid_f32 = ref_space_uid.float()
uid_freqs = torch.einsum("bn,k->bnk", uid_f32, uid_inv_freq)
n_active = n_spatial_total + n_uid_pairs
freqs = torch.cat([spatial_freqs, uid_freqs], dim=-1)
if n_active < half_dim:
padding = torch.zeros(
B, N, half_dim - n_active, device=device, dtype=torch.float32
)
freqs = torch.cat([freqs, padding], dim=-1)
cos = freqs.cos().to(torch.bfloat16)
sin = freqs.sin().to(torch.bfloat16)
return cos, sin
def qk_norm(x: Tensor) -> Tensor:
return F.rms_norm(x, (x.size(-1),)).to(x.dtype)
# ===========================================================================
# SwiGLUFFN (atom transformer blocks)
# ===========================================================================
class SwiGLUFFN(nn.Module):
    """Bias-free SwiGLU feed-forward whose hidden width is rounded up to a multiple of 256.

    Args:
        d_model: input and output width.
        expansion_ratio: multiplier applied before rounding (see ``__init__``).
    """

    def __init__(self, d_model: int, expansion_ratio: int = 2) -> None:
        super().__init__()
        # Raw width is expansion_ratio * floor(d_model / 3) * 2, then rounded
        # up to the next multiple of 256.
        raw_width = expansion_ratio * (d_model // 3) * 2
        hidden_size = (raw_width + 255) // 256 * 256
        self.w_up = nn.Linear(d_model, 2 * hidden_size, bias=False)
        self.w_down = nn.Linear(hidden_size, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Cast ``x`` to the weight dtype and apply the gated feed-forward."""
        x = x.to(self.w_up.weight.dtype)
        gate_half, value_half = torch.chunk(self.w_up(x), 2, dim=-1)
        return self.w_down(F.silu(gate_half) * value_half)
# ===========================================================================
# SWA3DRoPEAttention
# ===========================================================================
class SWA3DRoPEAttention(nn.Module):
    """Gated sliding-window self-attention with 3D RoPE. Has Wqkv, gate_proj, out_proj.

    Queries and keys are RMS-normalised and rotated with the supplied cos/sin
    tables. flash-attn is used when it is importable and the tensors are on
    CUDA. Otherwise a dense fallback is used that applies the same
    ``half_window`` band and the same padding information, so every path
    attends to the same set of keys.

    Args:
        d_model: model width (split evenly over ``n_heads``).
        n_heads: number of attention heads.
        half_window: number of neighbours visible on each side of a query.
            A negative value disables the window, matching flash-attn.
    """

    def __init__(self, d_model: int, n_heads: int, half_window: int = 64) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim**-0.5
        self.half_window = half_window
        self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.gate_proj = nn.Linear(d_model, d_model, bias=False)

    def _dense_attention_masks(
        self, attention_params: tuple, batch: int, length: int, device: torch.device
    ) -> tuple[Tensor, Tensor | None]:
        """Build the masks that make the dense fallback mirror the flash paths.

        Returns:
            ``allowed``: bool ``[batch or 1, length, length]``, True where a
            query (dim 1) may attend to a key (dim 2).
            ``valid``: bool ``[batch, length]`` marking non-padded positions,
            or None when ``attention_params`` carries no packing indices.
        """
        if len(attention_params) > 2:
            # attention_params[2] lists the flat indices of real positions.
            flat_valid = torch.zeros(batch * length, dtype=torch.bool, device=device)
            flat_valid[attention_params[2].to(device=device, dtype=torch.long)] = True
            valid: Tensor | None = flat_valid.view(batch, length)
            # The varlen kernel sees each sequence with padding removed, so the
            # window is measured in packed coordinates.
            positions = valid.cumsum(dim=-1, dtype=torch.int32) - 1
        else:
            valid = None
            positions = torch.arange(
                length, device=device, dtype=torch.int32
            ).unsqueeze(0)

        if self.half_window >= 0:
            distance = (positions[:, :, None] - positions[:, None, :]).abs()
            allowed = distance <= self.half_window
        else:
            allowed = torch.ones(
                positions.shape[0], length, length, dtype=torch.bool, device=device
            )
        if valid is not None:
            # Padded positions are never usable as keys.
            allowed = allowed & valid[:, None, :]
        return allowed, valid

    def forward(self, x: Tensor, attention_params: tuple) -> Tensor:
        """Apply gated windowed attention.

        Args:
            x: ``[B, N, d_model]`` activations.
            attention_params: ``(cos, sin)`` or
                ``(cos, sin, indices, cu_seqlens, max_seqlen)``; the longer
                form additionally describes which positions are real.

        Returns:
            ``[B, N, d_model]``.
        """
        B, N = x.shape[:2]
        cos, sin = attention_params[0], attention_params[1]
        x_input = x
        qkv = self.Wqkv(x)
        qkv = qkv.view(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 1, 3, 4)
        q, k, v = qkv.unbind(0)
        q, k = qk_norm(q), qk_norm(k)
        q = apply_rotary_emb_3d(q, cos, sin)
        k = apply_rotary_emb_3d(k, cos, sin)
        input_dtype = q.dtype
        if q.dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()

        # flash-attn kernels only run on CUDA tensors; route everything else
        # through the dense fallback instead of failing inside the kernel.
        use_flash = FLASH_ATTN_AVAILABLE and q.is_cuda
        has_packing = len(attention_params) > 2

        if use_flash and has_packing:
            indices, cu_seqlens, max_seqlen = (
                attention_params[2],
                attention_params[3],
                attention_params[4],
            )
            q_unpad = index_first_axis(  # type: ignore[misc]
                q.reshape(-1, self.n_heads, self.head_dim), indices
            )
            k_unpad = index_first_axis(  # type: ignore[misc]
                k.reshape(-1, self.n_heads, self.head_dim), indices
            )
            v_unpad = index_first_axis(  # type: ignore[misc]
                v.reshape(-1, self.n_heads, self.head_dim), indices
            )
            out_unpad = flash_attn_varlen_func(  # type: ignore[misc]
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                softmax_scale=self.scale,
                window_size=(self.half_window, self.half_window),
            )
            out = pad_input(out_unpad, indices, B, N)  # type: ignore[misc]
        elif use_flash:
            out = flash_attn_func(  # type: ignore[misc]
                q,
                k,
                v,
                softmax_scale=self.scale,
                window_size=(self.half_window, self.half_window),
            )
        else:
            # Fallback: dense attention restricted to the same window and
            # padding pattern as the flash paths.
            q_t = q.transpose(1, 2)
            k_t = k.transpose(1, 2)
            v_t = v.transpose(1, 2)
            scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * self.scale
            allowed, valid = self._dense_attention_masks(
                attention_params, B, N, scores.device
            )
            # A large finite fill (rather than -inf) keeps fully masked rows
            # finite; those rows are zeroed below.
            scores = scores.masked_fill(
                ~allowed.unsqueeze(1), torch.finfo(scores.dtype).min
            )
            attn = F.softmax(scores, dim=-1)
            out = torch.matmul(attn, v_t).transpose(1, 2)
            if valid is not None:
                # Padded query rows are zero, as they are after pad_input.
                out = out.masked_fill(~valid[:, :, None, None], 0.0)
        out = out.to(input_dtype).reshape(B, N, -1)  # type: ignore[union-attr]
        out = out * torch.sigmoid(self.gate_proj(x_input))
        return self.out_proj(out)
# ===========================================================================
# SWAAtomBlock, SWAAtomTransformer
# ===========================================================================
def _rms_adaln_raw(x: Tensor, scale: Tensor, shift: Tensor) -> Tensor:
return F.rms_norm(x, (x.shape[-1],)) * (1 + scale) + shift
def _gated_residual_raw(x: Tensor, gate: Tensor, y: Tensor) -> Tensor:
    """Add the branch output ``y``, scaled elementwise by ``gate``, to the residual ``x``."""
    gated_update = gate * y
    return x + gated_update
class SWAAtomBlock(nn.Module):
    """Atom transformer block: conditioned modulation around SWA attention and a SwiGLU FFN.

    The conditioning is mapped by ``adaln_modulation`` (``Sequential(SiLU(),
    Linear)``, hence checkpoint keys such as ``adaln_modulation.1.weight``)
    to six chunks: shift/scale/gate for the attention branch and for the FFN
    branch. The projection weight starts at zero.

    Args:
        d_atom: atom feature width.
        n_heads: attention heads.
        half_window: one-sided attention window passed to ``SWA3DRoPEAttention``.
        expansion_ratio: passed to ``SwiGLUFFN``.
        use_compile_fusions: wrap the two elementwise helpers in ``torch.compile``.
    """

    def __init__(
        self,
        d_atom: int,
        n_heads: int,
        half_window: int = 64,
        expansion_ratio: int = 2,
        use_compile_fusions: bool = False,
    ) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(d_atom, elementwise_affine=False)
        self.ffn_norm = nn.RMSNorm(d_atom, elementwise_affine=False)
        modulation_proj = nn.Linear(d_atom, 6 * d_atom, bias=False)
        nn.init.zeros_(modulation_proj.weight)
        self.adaln_modulation = nn.Sequential(nn.SiLU(), modulation_proj)
        self.attn = SWA3DRoPEAttention(d_atom, n_heads, half_window=half_window)
        self.ffn = SwiGLUFFN(d_atom, expansion_ratio)
        if use_compile_fusions:
            self._rms_adaln = torch.compile(_rms_adaln_raw)
            self._gated_residual = torch.compile(_gated_residual_raw)
        else:
            self._rms_adaln = _rms_adaln_raw
            self._gated_residual = _gated_residual_raw

    def forward(self, x: Tensor, c_l: Tensor, attention_params: tuple) -> Tensor:
        """Run the attention and FFN branches, each as a gated residual update."""
        modulation = self.adaln_modulation(c_l)
        if modulation.dim() == 2:
            # Per-sample conditioning: add a length axis so it broadcasts.
            modulation = modulation.unsqueeze(1)
        (
            attn_shift,
            attn_scale,
            attn_gate,
            ffn_shift,
            ffn_scale,
            ffn_gate,
        ) = modulation.chunk(6, dim=-1)

        attn_update = self.attn(
            self._rms_adaln(x, attn_scale, attn_shift), attention_params
        )
        x = self._gated_residual(x, attn_gate, attn_update)

        ffn_update = self.ffn(self._rms_adaln(x, ffn_scale, ffn_shift))
        return self._gated_residual(x, ffn_gate, ffn_update)
class SWAAtomTransformer(nn.Module):
    """A sequential stack of ``SWAAtomBlock`` layers sharing one RoPE configuration.

    Args:
        d_atom: atom feature width.
        n_blocks: number of blocks.
        n_heads: attention heads per block.
        swa_window_size: full window width; each block gets half of it per side.
        expansion_ratio: FFN expansion passed to every block.
        spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis,
        n_uid_rope_pairs, uid_rope_base_frequency: stored and forwarded to
            ``build_3d_rope`` by ``_build_3d_rope``.
    """

    def __init__(
        self,
        d_atom: int = 128,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.swa_window_size = swa_window_size
        self.head_dim = d_atom // n_heads
        self.spatial_rope_base_frequency = spatial_rope_base_frequency
        self.n_spatial_rope_pairs_per_axis = n_spatial_rope_pairs_per_axis
        self.n_uid_rope_pairs = n_uid_rope_pairs
        self.uid_rope_base_frequency = uid_rope_base_frequency
        block_kwargs = {
            "d_atom": d_atom,
            "n_heads": n_heads,
            "half_window": swa_window_size // 2,
            "expansion_ratio": expansion_ratio,
        }
        self.blocks = nn.ModuleList(
            [SWAAtomBlock(**block_kwargs) for _ in range(n_blocks)]
        )

    def _build_3d_rope(
        self, ref_pos: Tensor, ref_space_uid: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Build cos/sin tables for these inputs using this module's RoPE settings."""
        return build_3d_rope(
            ref_pos,
            ref_space_uid,
            self.head_dim,
            n_spatial_per_axis=self.n_spatial_rope_pairs_per_axis,
            n_uid_pairs=self.n_uid_rope_pairs,
            spatial_base_freq=self.spatial_rope_base_frequency,
            uid_base_freq=self.uid_rope_base_frequency,
        )

    def forward(
        self,
        q_l: Tensor,
        c_l: Tensor,
        attention_params: tuple,
        return_intermediates: bool = False,
    ) -> Tensor | tuple[Tensor, list[Tensor]]:
        """Apply every block in order.

        Returns the final activations, or ``(final, per_block_outputs)`` when
        ``return_intermediates`` is True.
        """
        per_block_outputs: list[Tensor] = []
        hidden = q_l
        for layer in self.blocks:
            hidden = layer(hidden, c_l, attention_params)
            if return_intermediates:
                per_block_outputs.append(hidden)
        if not return_intermediates:
            return hidden
        return hidden, per_block_outputs
# ===========================================================================
# ESMFold2AtomEncoder (for both inputs_embedder and diffusion_module)
# ===========================================================================
class ESMFold2AtomEncoder(nn.Module):
    """SWA atom encoder with atom_linear, atom_norm, atom_to_token_linear, [coords_linear], atom_transformer.

    Embeds per-atom reference features, optionally adds a projection of
    coordinate inputs, runs the ``SWAAtomTransformer`` and mean-pools the
    ReLU-activated atom features onto tokens.

    Args:
        d_atom: atom hidden dim
        d_token: token dim for atom_to_token aggregation
        n_blocks, n_heads, swa_window_size, expansion_ratio: transformer params
        structure_prediction: if True, creates coords_linear and uses full d_token
        spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis,
        n_uid_rope_pairs, uid_rope_base_frequency: 3D RoPE config
    """

    def __init__(
        self,
        d_atom: int = 128,
        d_token: int = 768,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        structure_prediction: bool = True,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.d_atom = d_atom
        self.d_token = d_token
        self.structure_prediction = structure_prediction
        self.atom_linear = nn.Linear(ATOM_FEATURE_DIM, d_atom, bias=False)
        self.atom_norm = nn.LayerNorm(d_atom)
        if structure_prediction:
            # Input width 6 = r_l concatenated with pred_r1 along the last dim
            # (see forward); presumably 3 coordinates each — TODO confirm.
            self.coords_linear = nn.Linear(6, d_atom, bias=False)
        self.atom_transformer = SWAAtomTransformer(
            d_atom=d_atom,
            n_blocks=n_blocks,
            n_heads=n_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=expansion_ratio,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        # Output aggregation: d_token for structure prediction, d_token//2 for inputs
        out_dim = d_token if structure_prediction else d_token // 2
        self.atom_to_token_linear = nn.Linear(d_atom, out_dim, bias=False)

    def forward(
        self,
        ref_pos: Tensor,
        atom_attention_mask: Tensor,
        ref_space_uid: Tensor,
        ref_charge: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        atom_to_token: Tensor,
        r_l: Tensor | None = None,
        pred_r1: Tensor | None = None,
        s_i: Tensor | None = None,
        z_ij: Tensor | None = None,
        num_diffusion_samples: int = 1,
        return_intermediates: bool = False,
        inference_cache: dict | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, tuple, list[Tensor]]:
        """Returns (a, q, c, attention_params, intermediates).

        ``inference_cache`` caches step-invariant tensors (c_base, 3D RoPE,
        attention indices, n_tokens) across diffusion steps.

        ``a`` is the per-token aggregate, ``q`` the per-atom transformer
        output and ``c`` the per-atom conditioning fed to the transformer.
        ``intermediates`` is empty unless ``return_intermediates`` is True.
        ``s_i`` and ``z_ij`` are accepted but not read in this method.
        """
        B, N = ref_pos.shape[:2]
        layer_cache = None
        if inference_cache is not None:
            # This module's private sub-dict inside the shared cache.
            layer_cache = inference_cache.setdefault("atomencoder", {})
        if layer_cache is None or len(layer_cache) == 0:
            # No cache, or first call with a cache: compute everything that
            # depends only on the reference features.
            atom_feats = torch.cat(
                [
                    ref_pos,
                    ref_charge.unsqueeze(-1),
                    atom_attention_mask.unsqueeze(-1),
                    ref_element,
                    ref_atom_name_chars.reshape(B, N, MAX_CHARS * CHAR_VOCAB_SIZE),
                ],
                dim=-1,
            )
            c_base = self.atom_norm(self.atom_linear(atom_feats))
            cos, sin = self.atom_transformer._build_3d_rope(ref_pos, ref_space_uid)
            # Replicate per diffusion sample so the tables line up with the
            # expanded batch.
            cos = cos.repeat_interleave(num_diffusion_samples, 0)
            sin = sin.repeat_interleave(num_diffusion_samples, 0)
            mask_exp = atom_attention_mask.repeat_interleave(num_diffusion_samples, 0)
            # Packing metadata consumed by the variable-length attention path:
            # per-row lengths, flat indices of unmasked atoms, longest row,
            # and cumulative lengths with a leading 0.
            seqlens = mask_exp.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(mask_exp.flatten(), as_tuple=False).flatten()
            max_seqlen = int(seqlens.max().item())
            cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
            attention_params = (cos, sin, indices, cu_seqlens, max_seqlen)
            # Token count is inferred from the largest token index present.
            n_tokens = int(atom_to_token.max().item()) + 1
            if layer_cache is not None:
                layer_cache["c_base"] = c_base
                layer_cache["attention_params"] = attention_params
                layer_cache["mask_exp"] = mask_exp
                layer_cache["n_tokens"] = n_tokens
                layer_cache["atom_to_token_exp"] = atom_to_token.repeat_interleave(
                    num_diffusion_samples, 0
                )
        else:
            # Cache hit: the reference-feature arguments are not re-read.
            c_base = layer_cache["c_base"]
            attention_params = layer_cache["attention_params"]
            mask_exp = layer_cache["mask_exp"]
            n_tokens = layer_cache["n_tokens"]
        c = c_base
        q = c
        if self.structure_prediction and r_l is not None:
            q = q.repeat_interleave(num_diffusion_samples, 0)
            if pred_r1 is None:
                pred_r1 = torch.zeros_like(r_l)
            r_input = torch.cat([r_l, pred_r1], dim=-1)
            r_to_q = self.coords_linear(r_input)
            q = q + r_to_q
            c = c.repeat_interleave(num_diffusion_samples, 0)
        # NOTE(review): when this branch is skipped, q and c keep the original
        # batch size while attention_params/mask_exp were expanded by
        # num_diffusion_samples; presumably callers pass num_diffusion_samples=1
        # in that case — confirm.
        result = self.atom_transformer(
            q_l=q,
            c_l=c,
            attention_params=attention_params,
            return_intermediates=return_intermediates,
        )
        if return_intermediates:
            q, intermediates = result
        else:
            q = result
            intermediates = []
        q_to_a = F.relu(self.atom_to_token_linear(q))
        if layer_cache is not None and "atom_to_token_exp" in layer_cache:
            atom_to_token_exp = layer_cache["atom_to_token_exp"]
        else:
            atom_to_token_exp = atom_to_token.repeat_interleave(
                num_diffusion_samples, 0
            )
        # Mean over the unmasked atoms of each token.
        a = scatter_atom_to_token(
            q_to_a, atom_to_token_exp, n_tokens, atom_mask=mask_exp.bool()
        )
        return a, q, c, attention_params, intermediates
# ===========================================================================
# ESMFold2AtomDecoder
# ===========================================================================
class ESMFold2AtomDecoder(nn.Module):
    """SWA atom decoder: token_to_atom_linear, atom_transformer, norm, output_linear.

    Projects token activations to atom width, broadcasts them onto atoms,
    adds them to the incoming atom activations, runs the atom transformer and
    maps the normalised result to ``XYZ_DIMS`` values per atom.

    Args:
        d_atom: atom feature width.
        d_token: token feature width.
        n_blocks, n_heads, swa_window_size, expansion_ratio: transformer params.
        spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis,
        n_uid_rope_pairs, uid_rope_base_frequency: 3D RoPE config.
    """

    def __init__(
        self,
        d_atom: int = 128,
        d_token: int = 768,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.token_to_atom_linear = nn.Linear(d_token, d_atom, bias=False)
        transformer_config = {
            "d_atom": d_atom,
            "n_blocks": n_blocks,
            "n_heads": n_heads,
            "swa_window_size": swa_window_size,
            "expansion_ratio": expansion_ratio,
            "spatial_rope_base_frequency": spatial_rope_base_frequency,
            "n_spatial_rope_pairs_per_axis": n_spatial_rope_pairs_per_axis,
            "n_uid_rope_pairs": n_uid_rope_pairs,
            "uid_rope_base_frequency": uid_rope_base_frequency,
        }
        self.atom_transformer = SWAAtomTransformer(**transformer_config)
        self.norm = nn.LayerNorm(d_atom)
        self.output_linear = nn.Linear(d_atom, XYZ_DIMS, bias=False)

    def forward(
        self,
        a_i: Tensor,
        q_l: Tensor,
        c_l: Tensor,
        p_lm: tuple,
        atom_to_token: Tensor,
        atom_attention_mask: Tensor,
        num_diffusion_samples: int = 1,
        return_intermediates: bool = False,
    ) -> tuple[Tensor, list[Tensor]]:
        """Returns (r_update, intermediates).

        ``p_lm`` is forwarded to the atom transformer as its attention
        parameters. ``atom_attention_mask`` is accepted but not read here.
        ``intermediates`` is empty unless ``return_intermediates`` is True.
        """
        # Match the atom->token map to the per-sample expanded batch.
        expanded_map = atom_to_token.repeat_interleave(num_diffusion_samples, 0)
        token_signal = gather_token_to_atom(
            self.token_to_atom_linear(a_i), expanded_map
        )
        atom_state = q_l + token_signal

        transformed = self.atom_transformer(
            q_l=atom_state,
            c_l=c_l,
            attention_params=p_lm,
            return_intermediates=return_intermediates,
        )
        if return_intermediates:
            atom_state, intermediates = transformed
        else:
            atom_state, intermediates = transformed, []

        r_l = self.output_linear(self.norm(atom_state))
        return r_l, intermediates
# ===========================================================================
# AttentionPairBias (DiffusionTransformer attention block)
# ===========================================================================
class AttentionPairBias(nn.Module):
    """Gated multi-head attention with pair bias conditioning.

    The pair representation ``z`` contributes an additive per-head bias to the
    attention logits. Three execution paths exist: a fused kernel path, a
    cuequivariance path, and a reference PyTorch path; which one runs is
    decided per call by the ``_can_use_*`` checks.

    Args:
        d_model: width of the activations.
        d_pair: width of the pair representation; when 0, no pair norm or
            projection is created.
        num_heads: number of attention heads.
        d_cond: width of the conditioning signal; defaults to ``d_model``.
        use_conditioning: create ``adaln``/``out_gate`` (conditioned mode)
            instead of a plain ``pre_norm``.
    """

    def __init__(
        self,
        d_model: int,
        d_pair: int,
        num_heads: int,
        d_cond: int | None = None,
        use_conditioning: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim**-0.5
        d_cond = d_cond or d_model
        if use_conditioning:
            self.adaln = AdaptiveLayerNorm(d_model, d_cond, eps=1e-5)
            self.out_gate = nn.Linear(d_cond, d_model, bias=True)
            # adaln init: weight=0, bias=-2
            # (so the output gate starts at sigmoid(-2) for every input)
            nn.init.zeros_(self.out_gate.weight)
            nn.init.constant_(self.out_gate.bias, -2.0)
        else:
            self.pre_norm = nn.LayerNorm(d_model, eps=1e-5)
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.g_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        if d_pair > 0:
            self.pair_norm = nn.LayerNorm(d_pair, eps=1e-5)
            self.pair_bias_proj = nn.Linear(d_pair, num_heads, bias=False)
        # None selects the reference path; see set_kernel_backend.
        self._kernel_backend: str | None = None

    def set_kernel_backend(self, backend: str | None) -> None:
        """Select the kernel backend; raises ValueError for unknown names."""
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        self._kernel_backend = backend

    def _is_zero_beta(self, beta: Tensor | float) -> bool:
        """Return True when ``beta`` is the number 0 or an all-zero tensor."""
        if isinstance(beta, (int, float)):
            return beta == 0.0
        # Tensor case: reduces over every element before converting to bool.
        return bool((beta == 0).all())

    def _can_use_fused_pair_bias(
        self, z: Tensor, n_queries: int, beta: Tensor | float
    ) -> bool:
        """Check the preconditions of the fused pair-bias path (``n_queries`` is not consulted)."""
        return (
            _fused_active(self, z)
            and z.dim() == 4
            and self._is_zero_beta(beta)
            and hasattr(self, "pair_bias_proj")
            and hasattr(self, "pair_norm")
        )

    def _can_use_cueq_pair_bias(
        self, z: Tensor, n_queries: int, beta: Tensor | float
    ) -> bool:
        """Check the preconditions of the cuequivariance path (only used above 750 queries)."""
        return (
            _cueq_active(self)
            and n_queries > 750
            and z.dim() == 4
            and self._is_zero_beta(beta)
            and hasattr(self, "pair_bias_proj")
        )

    def forward(
        self,
        a: Tensor,
        s: Tensor | None,
        z: Tensor,
        beta: Tensor | float = 0.0,
        attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
    ) -> Tensor:
        """Apply pair-biased gated self-attention to ``a``.

        Args:
            a: ``[bsz, n_queries, d_model]`` activations.
            s: conditioning for ``adaln``/``out_gate``; when None, ``pre_norm``
                is used and no output gate is applied.
            z: pair input. A 4-D tensor is normalised and projected to a
                per-head bias; otherwise it is used directly as a bias shared
                by all heads (reference path only).
            beta: inspected only by the kernel eligibility checks.
            attention_mask: optional key mask; truthy entries are kept.
            num_diffusion_samples: repeat factor applied to ``z`` and
                ``attention_mask`` when their batch size differs from ``bsz``.

        Returns:
            Tensor shaped like ``a``.

        NOTE(review): no branch of this method adds ``beta`` to the logits;
        a non-zero ``beta`` only forces the reference path — confirm intended.
        """
        bsz, n_queries, d_model = a.shape
        if s is not None:
            x = self.adaln(a, s)
        else:
            x = self.pre_norm(a)
        n_keys = x.shape[1]
        q = self.q_proj(x).view(bsz, n_queries, self.num_heads, self.head_dim)
        kv = self.kv_proj(x)
        k, v = kv.chunk(2, dim=-1)
        k = k.view(bsz, n_keys, self.num_heads, self.head_dim)
        v = v.view(bsz, n_keys, self.num_heads, self.head_dim)
        # Expand z for num_diffusion_samples
        if z.dim() == 4 and z.shape[0] != bsz and num_diffusion_samples > 1:
            z = z.repeat_interleave(num_diffusion_samples, dim=0)
        if (
            attention_mask is not None
            and attention_mask.shape[0] != bsz
            and num_diffusion_samples > 1
        ):
            attention_mask = attention_mask.repeat_interleave(
                num_diffusion_samples, dim=0
            )
        if self._can_use_fused_pair_bias(z, n_queries, beta):
            # Fused path: the kernel produces the per-head bias and SDPA
            # performs the attention. This branch returns on its own.
            kernel_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool)
            )
            pair_norm_w = self.pair_norm.weight
            pair_norm_b = (
                self.pair_norm.bias
                if self.pair_norm.bias is not None
                else torch.zeros_like(pair_norm_w)
            )
            z_bf = z if z.dtype == torch.bfloat16 else z.to(torch.bfloat16)
            bias = _fused_pair_bias( # type: ignore[misc]
                z_bf,
                kernel_mask,
                self.pair_bias_proj.weight,
                num_heads=self.num_heads,
                pair_norm_w=pair_norm_w,
                pair_norm_b=pair_norm_b,
            ) # (B, H, Q, K)
            q_bhqd = q.transpose(1, 2)
            k_bhqd = k.transpose(1, 2)
            v_bhqd = v.transpose(1, 2)
            attn_out = F.scaled_dot_product_attention(
                q_bhqd, k_bhqd, v_bhqd, attn_mask=bias.to(q_bhqd.dtype)
            )
            g = torch.sigmoid(self.g_proj(x)).view(
                bsz, n_queries, self.num_heads, self.head_dim
            )
            ctx = g * attn_out.transpose(1, 2)
            out = self.out_proj(ctx.reshape(bsz, n_queries, d_model))
            if s is not None:
                out = torch.sigmoid(self.out_gate(s)) * out
            return out
        if self._can_use_cueq_pair_bias(z, n_queries, beta):
            # cuequivariance path: the gate and output projection weights are
            # handed to the kernel; only the conditioning gate is applied
            # afterwards, below.
            kernel_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool)
            )
            out, _ = _cue_attn_pair_bias( # type: ignore[misc]
                s=x,
                q=q.transpose(1, 2),
                k=k.transpose(1, 2),
                v=v.transpose(1, 2),
                z=z,
                mask=kernel_mask,
                num_heads=self.num_heads,
                w_proj_z=self.pair_bias_proj.weight,
                w_proj_g=self.g_proj.weight,
                w_proj_o=self.out_proj.weight,
                w_ln_z=self.pair_norm.weight,
                b_ln_z=self.pair_norm.bias,
                return_z_proj=False,
                is_cached_z_proj=False,
            )
        else:
            # Standard attention with pair bias
            g = torch.sigmoid(self.g_proj(x)).view(
                bsz, n_queries, self.num_heads, self.head_dim
            )
            # logits are laid out [..., query i, key j, head h].
            logits = (
                torch.einsum("... i h d, ... j h d -> ... i j h", q, k) * self.scale
            )
            if z.dim() == 4:
                pair_bias = self.pair_bias_proj(self.pair_norm(z))
            else:
                # Already a bias: add a trailing head axis so it broadcasts.
                pair_bias = z.unsqueeze(-1)
            logits = logits + pair_bias.to(dtype=logits.dtype)
            if attention_mask is not None:
                # Push masked keys (axis j) to the most negative finite value.
                min_val = torch.finfo(logits.dtype).min
                mask_bias = torch.where(
                    attention_mask.bool()[:, None, :, None], 0.0, min_val
                )
                logits = logits + mask_bias.to(dtype=logits.dtype)
            # Softmax over the key axis (dim=-2 in the i, j, h layout).
            attn = torch.softmax(logits, dim=-2).to(dtype=v.dtype)
            ctx = torch.einsum("... i j h, ... j h d -> ... i h d", attn, v)
            ctx = g * ctx
            out = self.out_proj(ctx.reshape(bsz, n_queries, d_model))
        if s is not None:
            out = torch.sigmoid(self.out_gate(s)) * out
        return out
# ===========================================================================
# ConditionedTransitionBlock
# ===========================================================================
class ConditionedTransitionBlock(nn.Module):
"""Conditioned SwiGLU transition with adaptive layer norm."""
def __init__(
self,
d_model: int,
d_cond: int | None = None,
transition_multiplier: int = 2,
use_conditioning: bool = True,
) -> None:
super().__init__()
d_cond = d_cond or d_model
hidden = transition_multiplier * d_model
if use_conditioning:
self.adaln = AdaptiveLayerNorm(d_model, d_cond, eps=1e-5)
self.output_gate = nn.Linear(d_cond, d_model, bias=True)
nn.init.zeros_(self.output_gate.weight)
nn.init.constant_(self.output_gate.bias, -2.0)
else:
self.pre_norm = nn.LayerNorm(d_model, eps=1e-5)
self.lin_swish = nn.Linear(d_model, 2 * hidden, bias=False)
self.lin_out = nn.Linear(hidden, d_model, bias=False)
def forward(self, a: Tensor, s: Tensor | None) -> Tensor:
if s is not None:
x = self.adaln(a, s)
else:
x = self.pre_norm(a)
swish_a, swish_b = self.lin_swish(x).chunk(2, dim=-1)
b = F.silu(swish_a) * swish_b
out = self.lin_out(b)
if s is not None:
out = torch.sigmoid(self.output_gate(s)) * out
return out
# ===========================================================================
# DiffusionTransformer (token transformer)
# ===========================================================================
class DiffusionTransformer(nn.Module):
"""Diffusion denoising transformer with attention pair bias."""
def __init__(
self,
d_model: int,
d_pair: int,
num_heads: int,
num_blocks: int,
d_cond: int | None = None,
transition_multiplier: int = 2,
use_conditioning: bool = True,
) -> None:
super().__init__()
d_cond = d_cond or d_model
self.attn_blocks = nn.ModuleList(
[
AttentionPairBias(
d_model=d_model,
d_pair=d_pair,
num_heads=num_heads,
d_cond=d_cond,
use_conditioning=use_conditioning,
)
for _ in range(num_blocks)
]
)
self.transition_blocks = nn.ModuleList(
[
ConditionedTransitionBlock(
d_model=d_model,
d_cond=d_cond,
transition_multiplier=transition_multiplier,
use_conditioning=use_conditioning,
)
for _ in range(num_blocks)
]
)
def set_kernel_backend(self, backend: str | None) -> None:
for attn in self.attn_blocks:
cast(AttentionPairBias, attn).set_kernel_backend(backend)
def forward(
self,
a: Tensor,
s: Tensor | None,
z: Tensor,
beta: Tensor | float = 0.0,
attention_mask: Tensor | None = None,
num_diffusion_samples: int = 1,
return_intermediates: bool = False,
) -> tuple[Tensor, list[Tensor]]:
intermediates: list[Tensor] = []
x = a
for attn, transition in zip(self.attn_blocks, self.transition_blocks):
x = x + attn(
x,
s,
z,
beta,
attention_mask=attention_mask,
num_diffusion_samples=num_diffusion_samples,
)
x = x + transition(x, s)
if return_intermediates:
intermediates.append(x)
return x, intermediates
# ===========================================================================
# DiffusionConditioning
# ===========================================================================
class DiffusionConditioning(nn.Module):
"""Conditions pair and single representations on noise timestep."""
def __init__(
self,
c_z: int = 256,
c_s: int = 768,
c_s_inputs: int = 451,
sigma_data: float = 16.0,
fourier_dim: int = 256,
transition_multiplier: int = 2,
layer_norm_eps: float = 1e-5,
) -> None:
super().__init__()
self.sigma_data = float(sigma_data)
self.c_z = c_z
self.c_s = c_s
self.c_s_inputs = c_s_inputs
self.z_input_norm = nn.LayerNorm(2 * c_z, eps=layer_norm_eps)
self.z_proj = nn.Linear(2 * c_z, c_z, bias=False)
self.z_transitions = nn.ModuleList(
[
TransitionLayer(c_z, n=transition_multiplier, eps=layer_norm_eps)
for _ in range(2)
]
)
self.s_input_norm = nn.LayerNorm(c_s_inputs, eps=layer_norm_eps)
self.s_proj = nn.Linear(c_s_inputs, c_s, bias=False)
self.fourier = FourierEmbedding(fourier_dim)
self.noise_norm = nn.LayerNorm(fourier_dim, eps=layer_norm_eps)
self.noise_proj = nn.Linear(fourier_dim, c_s, bias=False)
self.s_transitions = nn.ModuleList(
[
TransitionLayer(c_s, n=transition_multiplier, eps=layer_norm_eps)
for _ in range(2)
]
)
def forward(
self,
t_hat: Tensor,
s_inputs: Tensor,
s_trunk: Tensor | None,
z_trunk: Tensor,
relative_position_encoding: Tensor,
sigma_data: float | None = None,
num_diffusion_samples: int = 1,
inference_cache: dict[str, Tensor] | None = None,
) -> tuple[Tensor, Tensor]:
sigma = self.sigma_data if sigma_data is None else float(sigma_data)
base_batch = z_trunk.shape[0]
target_batch = base_batch * num_diffusion_samples
# z conditioning (cached across diffusion steps — independent of t_hat)
if inference_cache is not None and "z" in inference_cache:
z = inference_cache["z"]
else:
z_rel = relative_position_encoding.to(dtype=torch.float32)
z = torch.cat([z_trunk.to(dtype=torch.float32), z_rel], dim=-1)
z = self.z_proj(self.z_input_norm(z))
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
for block in self.z_transitions:
z = z + block(z)
if inference_cache is not None:
inference_cache["z"] = z
# s conditioning
s_inputs_eff = s_inputs
if s_inputs_eff.shape[0] != target_batch:
s_inputs_eff = s_inputs_eff.repeat_interleave(num_diffusion_samples, 0)
s = self.s_proj(self.s_input_norm(s_inputs_eff.to(dtype=torch.float32)))
# Noise embedding
t = torch.as_tensor(t_hat, dtype=torch.float32, device=s.device).reshape(-1)
if t.numel() == 1:
t = t.expand(target_batch)
elif t.shape[0] != target_batch:
t = t.repeat_interleave(num_diffusion_samples, 0)
t_noise = 0.25 * torch.log((t / sigma).clamp(min=1e-20))
n = self.fourier(t_noise)
n = self.noise_proj(self.noise_norm(n))
s = s + n.unsqueeze(1)
for block in self.s_transitions:
s = s + block(s)
return s, z
# ===========================================================================
# DiffusionModule
# ===========================================================================
class DiffusionModule(nn.Module):
    """Single denoising pass: conditioning, atom encoder, token transformer, atom decoder.

    Given noisy atom coordinates and a noise level, returns a denoised
    coordinate estimate formed as a noise-dependent blend of the noisy input
    and the network's coordinate update.
    """

    def __init__(
        self,
        c_atom: int = 128,
        c_token: int = 768,
        c_z: int = 256,
        c_s_inputs: int = 451,
        sigma_data: float = 16.0,
        fourier_dim: int = 256,
        atom_num_blocks: int = 3,
        atom_num_heads: int = 4,
        token_num_blocks: int = 12,
        token_num_heads: int = 16,
        transition_multiplier: int = 2,
        swa_window_size: int = 128,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        """Build the conditioning module, atom encoder/decoder and token transformer."""
        super().__init__()
        self.sigma_data = float(sigma_data)

        # The conditioned single stream is produced at token width (c_token).
        self.conditioning = DiffusionConditioning(
            c_z=c_z,
            c_s=c_token,
            c_s_inputs=c_s_inputs,
            sigma_data=sigma_data,
            fourier_dim=fourier_dim,
            transition_multiplier=transition_multiplier,
        )

        # Encoder and decoder share every hyperparameter below; only the
        # encoder additionally enables ``structure_prediction``.
        atom_stack_kwargs = dict(
            d_atom=c_atom,
            d_token=c_token,
            n_blocks=atom_num_blocks,
            n_heads=atom_num_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=2,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        self.atom_encoder = ESMFold2AtomEncoder(
            structure_prediction=True, **atom_stack_kwargs
        )
        self.atom_decoder = ESMFold2AtomDecoder(**atom_stack_kwargs)

        # Zero-initialised so the conditioned stream contributes nothing at init.
        token_proj = nn.Linear(c_token, c_token, bias=False)
        nn.init.zeros_(token_proj.weight)
        self.s_to_token = token_proj

        self.token_transformer = DiffusionTransformer(
            d_model=c_token,
            d_pair=c_z,
            num_heads=token_num_heads,
            num_blocks=token_num_blocks,
            d_cond=c_token,
            transition_multiplier=transition_multiplier,
            use_conditioning=True,
        )
        self.s_step_norm = nn.LayerNorm(c_token)
        self.token_norm = nn.LayerNorm(c_token)

    def set_kernel_backend(self, backend: str | None) -> None:
        """Select the attention kernel backend of the token transformer."""
        self.token_transformer.set_kernel_backend(backend)

    def forward(
        self,
        x_noisy: Tensor,
        t_hat: Tensor,
        ref_pos: Tensor,
        ref_charge: Tensor,
        ref_mask: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        ref_space_uid: Tensor,
        tok_idx: Tensor,
        s_inputs: Tensor,
        s_trunk: Tensor | None,
        z_trunk: Tensor,
        relative_position_encoding: Tensor,
        asym_id: Tensor,
        residue_index: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
        sym_id: Tensor,
        sigma_data: float | None = None,
        token_attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
        return_token_repr: bool = False,
        return_atom_repr: bool = False,
        inference_cache: dict[str, Tensor] | None = None,
    ) -> dict[str, Tensor | None]:
        """Denoise ``x_noisy`` at noise level ``t_hat``.

        ``asym_id``, ``residue_index``, ``entity_id``, ``token_index`` and
        ``sym_id`` are part of the signature but are not read in this method.

        Returns:
            Dict with ``"x_denoised"``, ``"token_repr"`` (``None`` unless
            ``return_token_repr``) and ``"atom_intermediates"`` (``None``
            unless ``return_atom_repr`` and intermediates were produced).
        """
        n_rows = x_noisy.shape[0]
        sigma = float(sigma_data) if sigma_data is not None else self.sigma_data

        noise_level = torch.as_tensor(
            t_hat, dtype=torch.float32, device=x_noisy.device
        ).reshape(-1)
        if noise_level.numel() == 1:
            noise_level = noise_level.expand(n_rows)

        # Step 1: conditioning (the pair stream is cached across diffusion steps)
        cond_single, cond_pair = self.conditioning(
            t_hat=noise_level,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            relative_position_encoding=relative_position_encoding,
            sigma_data=sigma,
            num_diffusion_samples=num_diffusion_samples,
            inference_cache=inference_cache,
        )

        # Step 2: scale the noisy coordinates by 1 / sqrt(t^2 + sigma^2)
        total_var = noise_level * noise_level + sigma * sigma
        total_std = torch.sqrt(total_var)
        scaled_coords = x_noisy / total_std[:, None, None]

        # Step 3: atom encoder
        tokens, q_skip, c_skip, p_skip, enc_intermediates = self.atom_encoder(
            ref_pos=ref_pos,
            atom_attention_mask=ref_mask,
            ref_space_uid=ref_space_uid,
            ref_charge=ref_charge,
            ref_element=ref_element,
            ref_atom_name_chars=ref_atom_name_chars,
            atom_to_token=tok_idx,
            r_l=scaled_coords,
            s_i=s_trunk,
            num_diffusion_samples=num_diffusion_samples,
            return_intermediates=return_atom_repr,
            inference_cache=inference_cache,
        )

        # Step 4: inject the conditioned single stream
        tokens = tokens + self.s_to_token(self.s_step_norm(cond_single))

        # Step 5: token transformer (its intermediates are discarded)
        tokens, _ = self.token_transformer(
            tokens,
            cond_single,
            cond_pair,
            beta=0.0,
            attention_mask=token_attention_mask,
            num_diffusion_samples=num_diffusion_samples,
        )

        # Step 6: token norm
        tokens = self.token_norm(tokens)

        # Step 7: atom decoder
        coord_update, dec_intermediates = self.atom_decoder(
            a_i=tokens,
            q_l=q_skip,
            c_l=c_skip,
            p_lm=p_skip,
            atom_to_token=tok_idx,
            atom_attention_mask=ref_mask,
            num_diffusion_samples=num_diffusion_samples,
            return_intermediates=return_atom_repr,
        )

        # Step 8: blend the noisy input and the predicted update
        sigma_sq = sigma * sigma
        noise_sq = noise_level * noise_level
        skip_weight = sigma_sq / (sigma_sq + noise_sq)
        update_weight = (sigma * noise_level) / torch.sqrt(sigma_sq + noise_sq)
        denoised = skip_weight[:, None, None] * x_noisy
        denoised = denoised + update_weight[:, None, None] * coord_update

        # Stack encoder + decoder atom intermediates when they were requested.
        atom_intermediates: Tensor | None = None
        if return_atom_repr:
            gathered = enc_intermediates + dec_intermediates
            if gathered:
                atom_intermediates = torch.stack(gathered, dim=2)

        return {
            "x_denoised": denoised,
            "token_repr": tokens if return_token_repr else None,
            "atom_intermediates": atom_intermediates,
        }
# ===========================================================================
# DiffusionStructureHead
# ===========================================================================
class DiffusionStructureHead(nn.Module):
    """Wrapper around DiffusionModule with diffusion sampling.

    Holds the denoiser plus the sampling hyperparameters read from the
    config, and implements the iterative sampler in :meth:`sample`.
    """

    def __init__(self, config: ESMFold2Config) -> None:
        """Build the denoiser and copy sampling hyperparameters from ``config``.

        Reads ``config.structure_head.diffusion_module`` for the denoiser
        sizes, ``config.inputs.atom_encoder`` for the atom-attention window
        and rotary settings, and ``config.structure_head`` for the sampler
        settings.
        """
        super().__init__()
        dm = config.structure_head.diffusion_module
        swa_cfg = config.inputs.atom_encoder
        sh = config.structure_head
        self.diffusion_module = DiffusionModule(
            c_atom=dm.c_atom,
            c_token=dm.c_token,
            c_z=dm.c_z,
            c_s_inputs=dm.c_s_inputs,
            sigma_data=dm.sigma_data,
            fourier_dim=dm.fourier_dim,
            atom_num_blocks=dm.atom_num_blocks,
            atom_num_heads=dm.atom_num_heads,
            token_num_blocks=dm.token_num_blocks,
            token_num_heads=dm.token_num_heads,
            transition_multiplier=dm.transition_multiplier,
            swa_window_size=swa_cfg.swa_window_size,
            spatial_rope_base_frequency=swa_cfg.spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=swa_cfg.n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=swa_cfg.n_uid_rope_pairs,
            uid_rope_base_frequency=swa_cfg.uid_rope_base_frequency,
        )
        # Sampling hyperparameters
        self.sigma_data = dm.sigma_data
        self.gamma_0 = sh.gamma_0
        self.gamma_min = sh.gamma_min
        self.noise_scale = sh.noise_scale
        self.step_scale = sh.step_scale
        self.inference_s_max = sh.inference_s_max
        self.inference_s_min = sh.inference_s_min
        self.inference_p = sh.inference_p
        self.inference_num_steps = sh.inference_num_steps

    def set_kernel_backend(self, backend: str | None) -> None:
        """Forward the kernel backend selection to the denoiser."""
        self.diffusion_module.set_kernel_backend(backend)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def inference_noise_schedule(
        self, num_steps: int | None = None, device: torch.device | None = None
    ) -> Tensor:
        """Karras power-law noise schedule.

        Entry ``k`` (``k = 0 .. steps - 1``) is
        ``sigma_data * (s_max**(1/p) + k/(steps-1) * (s_min**(1/p) - s_max**(1/p)))**p``
        and a final ``0.0`` is appended.

        Args:
            num_steps: Number of non-zero entries; defaults to
                ``self.inference_num_steps``.
            device: Device of the returned tensor.

        Returns:
            1-D float32 tensor with ``steps + 1`` entries.
        """
        steps = self.inference_num_steps if num_steps is None else int(num_steps)
        # A single step would divide by (steps - 1) == 0 below, so it is
        # special-cased to [s_max * sigma_data, 0].
        if steps == 1:
            return torch.tensor(
                [self.inference_s_max * self.sigma_data, 0.0],
                device=device,
                dtype=torch.float32,
            )
        p = float(self.inference_p)
        inv_p = 1.0 / p
        k = torch.arange(steps, device=device, dtype=torch.float32)
        base = self.inference_s_max**inv_p + (k / (steps - 1)) * (
            self.inference_s_min**inv_p - self.inference_s_max**inv_p
        )
        schedule = self.sigma_data * base.pow(p)
        # Append the terminal noise level 0.
        return F.pad(schedule, (0, 1), value=0.0)

    @staticmethod
    def _random_rotations(n: int, dtype: torch.dtype, device: torch.device) -> Tensor:
        """Sample ``n`` random 3x3 rotation matrices.

        Draws Gaussian 4-vectors, normalises them to unit quaternions (with
        the sign chosen so the real part is non-negative) and converts each
        quaternion to a rotation matrix.

        Returns:
            Tensor of shape ``(n, 3, 3)``.
        """
        q = torch.randn((n, 4), dtype=dtype, device=device)
        scale = torch.sqrt((q * q).sum(dim=1))
        # Dividing by a negated norm flips quaternions whose real part is negative.
        signs = torch.where(q[:, 0] < 0, -scale, scale)
        q = q / signs[:, None]
        r, i, j, k = torch.unbind(q, dim=-1)
        two_s = 2.0 / (q * q).sum(dim=-1)
        # Quaternion (r, i, j, k) -> row-major rotation matrix entries.
        return torch.stack(
            (
                1 - two_s * (j * j + k * k),
                two_s * (i * j - k * r),
                two_s * (i * k + j * r),
                two_s * (i * j + k * r),
                1 - two_s * (i * i + k * k),
                two_s * (j * k - i * r),
                two_s * (i * k - j * r),
                two_s * (j * k + i * r),
                1 - two_s * (i * i + j * j),
            ),
            dim=-1,
        ).reshape(n, 3, 3)

    def _center_random_augmentation(
        self, x: Tensor, atom_mask: Tensor, second_coords: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """Algorithm 19: center + random rotation + translation.

        ``x`` is centred on its mask-weighted mean, rotated by one random
        rotation per batch entry and shifted by one standard-normal
        translation per batch entry.  When ``second_coords`` is given, the
        same mean, rotation and translation are applied to it so both
        coordinate sets stay in the same frame.

        Returns:
            ``(x, second_coords)`` after augmentation (``second_coords`` stays
            ``None`` when it was not provided).
        """
        bsz = x.shape[0]
        mask = atom_mask.unsqueeze(-1)  # [B, A, 1]
        # Clamp so an all-masked entry does not divide by zero.
        denom = mask.sum(dim=1, keepdim=True).clamp(min=1)
        mean = (x * mask).sum(dim=1, keepdim=True) / denom
        x = x - mean
        if second_coords is not None:
            second_coords = second_coords - mean
        r = self._random_rotations(bsz, x.dtype, x.device)
        x = torch.einsum("bmd,bds->bms", x, r)
        if second_coords is not None:
            second_coords = torch.einsum("bmd,bds->bms", second_coords, r)
        # One translation vector per batch entry, broadcast over atoms.
        t = torch.randn_like(x[:, 0:1, :])
        x = x + t
        if second_coords is not None:
            second_coords = second_coords + t
        return x, second_coords

    @staticmethod
    def _weighted_rigid_align(
        x: Tensor, x_gt: Tensor, w: Tensor, mask: Tensor
    ) -> Tensor:
        """Kabsch alignment: align x to x_gt with weights w.

        The effective per-point weight is ``mask * w``.  Both point sets are
        centred on their weighted means, the optimal rotation is obtained
        from an SVD of the weighted cross-covariance (computed in float32),
        and ``x`` is returned rotated and translated onto ``x_gt``.

        Returns:
            The aligned copy of ``x``, same shape as ``x``.
        """
        w = (mask * w).unsqueeze(-1)  # [B, N, 1]
        denom = w.sum(dim=-2, keepdim=True).clamp(min=1e-8)
        mu = (x * w).sum(dim=-2, keepdim=True) / denom
        mu_gt = (x_gt * w).sum(dim=-2, keepdim=True) / denom
        x_c = x - mu
        xgt_c = x_gt - mu_gt
        H = torch.einsum("bni,bnj->bij", w * xgt_c, x_c)
        H32 = H.float()
        # The ``driver`` argument is only passed for CUDA inputs.
        U, _, Vh = torch.linalg.svd(H32, driver="gesvd" if H32.is_cuda else None)
        # Scale the last singular direction by det(U @ Vh) so that R is a
        # proper rotation (determinant +1) rather than a reflection.
        det = torch.linalg.det(U @ Vh)
        ones = torch.ones_like(det)
        R = (U @ torch.diag_embed(torch.stack([ones, ones, det], dim=-1)) @ Vh).to(
            H.dtype
        )
        return x_c @ R.transpose(-1, -2) + mu_gt

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def sample(
        self,
        z_trunk: Tensor,
        s_inputs: Tensor,
        s_trunk: Tensor | None,
        relative_position_encoding: Tensor,
        ref_pos: Tensor,
        ref_charge: Tensor,
        ref_mask: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        ref_space_uid: Tensor,
        tok_idx: Tensor,
        asym_id: Tensor,
        residue_index: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
        sym_id: Tensor,
        token_attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
        num_sampling_steps: int | None = None,
        max_inference_sigma: float | None = 256.0,
        noise_scale: float | None = None,
        step_scale: float | None = None,
        return_atom_repr: bool = False,
        use_inference_cache: bool = True,
        denoising_early_exit_rmsd: float | None = None,
    ) -> dict[str, Tensor | None]:
        """Diffusion sampling (Algorithm 18).

        The Karras schedule is built with ``num_sampling_steps`` entries
        (default ``self.inference_num_steps``).  When ``max_inference_sigma``
        is set, every entry above the cap is dropped and the cap itself is
        prepended as the starting noise level, so the number of denoising
        steps actually run is ``len(schedule) - 1`` *after* truncation and can
        be smaller than ``num_sampling_steps``.

        NOTE(review): this docstring previously stated that the schedule
        length is inflated so the post-truncation step count equals
        ``num_sampling_steps``.  No such inflation is implemented in this
        method — confirm which behaviour is intended.

        Args:
            num_diffusion_samples: Samples drawn per batch entry; the working
                batch is ``s_inputs.shape[0] * num_diffusion_samples``.
            noise_scale: Overrides ``self.noise_scale`` (scale of the noise
                added before each denoiser call).
            step_scale: Overrides ``self.step_scale`` (scale of each update).
            return_atom_repr: Also return the denoiser's atom intermediates.
            use_inference_cache: Share one cache dict across all steps.
            denoising_early_exit_rmsd: When set, stop as soon as the largest
                per-sample RMSD between consecutive denoised predictions
                (after rigid alignment) falls below this value.

        Returns:
            Dict with ``"sample_atom_coords"`` and ``"diff_token_repr"``, plus
            ``"diff_atom_intermediates"`` when ``return_atom_repr`` is set.
        """
        n_atoms = tok_idx.shape[1]
        device = s_inputs.device
        target_batch = s_inputs.shape[0] * num_diffusion_samples
        inference_cache: dict[str, Tensor] | None = {} if use_inference_cache else None
        steps = (
            self.inference_num_steps
            if num_sampling_steps is None
            else int(num_sampling_steps)
        )
        schedule = self.inference_noise_schedule(steps, device)
        if max_inference_sigma is not None:
            # Drop entries above the cap, then start the schedule at the cap.
            schedule = schedule[schedule <= float(max_inference_sigma)]
            schedule = F.pad(schedule, (1, 0), value=float(max_inference_sigma))
        lam = self.noise_scale if noise_scale is None else float(noise_scale)
        eta = self.step_scale if step_scale is None else float(step_scale)
        # Initial coordinates: Gaussian noise at the first noise level.
        x = schedule[0] * torch.randn(
            target_batch, n_atoms, 3, device=device, dtype=torch.float32
        )
        atom_mask = ref_mask.repeat_interleave(num_diffusion_samples, 0).float()
        # gamma is gamma_0 for noise levels above gamma_min and 0 otherwise.
        gammas = torch.where(
            schedule > self.gamma_min,
            torch.full_like(schedule, self.gamma_0),
            torch.zeros_like(schedule),
        )
        x_denoised_prev: Tensor | None = None
        token_repr: Tensor | None = None
        diff_atom_intermediates: Tensor | None = None
        # Each step goes from the previous level to the next one; gamma is
        # taken at the next level.
        step_pairs = list(zip(schedule[:-1], schedule[1:], gammas[1:]))
        num_steps = len(step_pairs)
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(step_pairs):
            # The previous prediction is moved with x so both stay in one frame.
            x, x_denoised_prev = self._center_random_augmentation(
                x, atom_mask, second_coords=x_denoised_prev
            )
            sigma_tm_val = float(sigma_tm.item())
            t_hat_val = sigma_tm_val * (1.0 + float(gamma.item()))
            # Noise needed to raise the level from sigma_tm to t_hat, scaled by lam.
            eps_std = lam * max(t_hat_val**2 - sigma_tm_val**2, 0.0) ** 0.5
            x_noisy = x + eps_std * torch.randn_like(x)
            is_last_step = step_idx == num_steps - 1
            # With early exit enabled any step may turn out to be the last one.
            request_atom_repr = return_atom_repr and (
                is_last_step or denoising_early_exit_rmsd is not None
            )
            dm_out = self.diffusion_module(
                x_noisy=x_noisy,
                t_hat=torch.full(
                    (target_batch,), t_hat_val, device=device, dtype=torch.float32
                ),
                ref_pos=ref_pos,
                ref_charge=ref_charge,
                ref_mask=ref_mask,
                ref_element=ref_element,
                ref_atom_name_chars=ref_atom_name_chars,
                ref_space_uid=ref_space_uid,
                tok_idx=tok_idx,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                relative_position_encoding=relative_position_encoding,
                asym_id=asym_id,
                residue_index=residue_index,
                entity_id=entity_id,
                token_index=token_index,
                sym_id=sym_id,
                token_attention_mask=token_attention_mask,
                num_diffusion_samples=num_diffusion_samples,
                return_token_repr=True,
                return_atom_repr=request_atom_repr,
                inference_cache=inference_cache,
            )
            x_denoised = dm_out["x_denoised"]
            token_repr = dm_out["token_repr"]
            if request_atom_repr:
                diff_atom_intermediates = dm_out.get("atom_intermediates")
            # Reverse diffusion alignment (Kabsch)
            with torch.autocast(device_type="cuda", enabled=False):
                x_noisy = self._weighted_rigid_align(
                    x_noisy.float(), x_denoised.float(), atom_mask, atom_mask
                )
            x_noisy = x_noisy.to(dtype=x_denoised.dtype)
            # ODE/SDE step
            sigma_t_val = float(sigma_t.item())
            denoised_over_sigma = (x_noisy - x_denoised) / t_hat_val
            x = x_noisy + eta * (sigma_t_val - t_hat_val) * denoised_over_sigma
            # Denoising early-exit: stop when consecutive predictions converge
            if (
                denoising_early_exit_rmsd is not None
                and x_denoised_prev is not None
                and step_idx >= 1
            ):
                with torch.autocast(device_type="cuda", enabled=False):
                    aligned = self._weighted_rigid_align(
                        x_denoised_prev.float(),
                        x_denoised.float(),
                        atom_mask,
                        atom_mask,
                    )
                    diff = (x_denoised.float() - aligned) * atom_mask.unsqueeze(-1)
                    per_sample_rmsd = (
                        diff.pow(2).sum(dim=(-1, -2)) / atom_mask.sum(dim=-1).clamp(min=1)
                    ).sqrt()
                # Every sample must have converged before exiting.
                if per_sample_rmsd.max().item() < denoising_early_exit_rmsd:
                    x = x_denoised
                    x_denoised_prev = x_denoised
                    break
            x_denoised_prev = x_denoised
        result: dict[str, Tensor | None] = {
            "sample_atom_coords": x,
            "diff_token_repr": token_repr,
        }
        if return_atom_repr:
            result["diff_atom_intermediates"] = diff_atom_intermediates
        return result
# ===========================================================================
# RowAttentionPooling
# ===========================================================================
class RowAttentionPooling(nn.Module):
    """Pool each row of a pair tensor into a single vector with learned attention.

    ``attn_proj`` scores every pair entry, a masked softmax over the last
    pair axis turns scores into weights, and ``out_proj`` maps the weighted
    sum to the single-representation width.
    """

    def __init__(self, d_pair: int, d_single: int) -> None:
        """Create the scoring (``d_pair -> 1``) and output (``d_pair -> d_single``) maps."""
        super().__init__()
        self.attn_proj = nn.Linear(d_pair, 1, bias=False)
        self.out_proj = nn.Linear(d_pair, d_single, bias=False)

    def forward(self, z: Tensor, mask: Tensor) -> Tensor:
        """Pool ``z`` over its third axis.

        Args:
            z: Pair tensor indexed as ``[b, n, m, d]``.
            mask: ``[b, m]`` mask; entries that are false/zero are excluded
                from the softmax by adding ``-1e9`` to their score.

        Returns:
            Tensor indexed as ``[b, n, d_single]``.
        """
        logits = self.attn_proj(z).squeeze(-1)
        visible = mask[:, None, :].bool()
        penalty = torch.where(
            visible,
            torch.zeros_like(logits),
            torch.full_like(logits, -1e9),
        )
        attn = torch.softmax(logits + penalty, dim=-1)
        row_summary = torch.einsum("bnm,bnmd->bnd", attn, z)
        return self.out_proj(row_summary)
# ===========================================================================
# InputsEmbedder
# ===========================================================================
class InputsEmbedder(nn.Module):
    """Build per-token input features from atom-level and sequence-level inputs.

    Atom features are encoded by an ``ESMFold2AtomEncoder`` configured from
    ``config.inputs.atom_encoder`` and concatenated with the residue-type,
    profile and deletion-mean features.
    """

    def __init__(self, config: ESMFold2Config) -> None:
        """Create the atom encoder from ``config.inputs.atom_encoder``."""
        super().__init__()
        enc_cfg = config.inputs.atom_encoder
        self.atom_attention_encoder = ESMFold2AtomEncoder(
            d_atom=enc_cfg.d_atom,
            d_token=enc_cfg.d_token,
            n_blocks=enc_cfg.n_blocks,
            n_heads=enc_cfg.n_heads,
            swa_window_size=enc_cfg.swa_window_size,
            expansion_ratio=enc_cfg.expansion_ratio,
            # Input embedding does not consume coordinates.
            structure_prediction=False,
            spatial_rope_base_frequency=enc_cfg.spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=enc_cfg.n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=enc_cfg.n_uid_rope_pairs,
            uid_rope_base_frequency=enc_cfg.uid_rope_base_frequency,
        )

    def forward(
        self,
        aatype: Tensor,
        profile: Tensor,
        deletion_mean: Tensor,
        ref_pos: Tensor,
        atom_attention_mask: Tensor,
        ref_space_uid: Tensor,
        ref_charge: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        atom_to_token: Tensor,
    ) -> Tensor:
        """Embed inputs into per-token features.

        Only the first output of the atom encoder is used; the remaining four
        are discarded.

        Returns:
            [B, L, d_inputs] concatenation (last axis) of the atom encoding,
            ``aatype``, ``profile`` and ``deletion_mean`` (given a trailing
            unit axis).
        """
        encoder_out = self.atom_attention_encoder(
            ref_pos=ref_pos,
            atom_attention_mask=atom_attention_mask,
            ref_space_uid=ref_space_uid,
            ref_charge=ref_charge,
            ref_element=ref_element,
            ref_atom_name_chars=ref_atom_name_chars,
            atom_to_token=atom_to_token,
        )
        atom_feats, _, _, _, _ = encoder_out
        per_token = (atom_feats, aatype, profile, deletion_mean.unsqueeze(-1))
        return torch.cat(per_token, dim=-1)
# ===========================================================================
# ResIdxAsymIdSymIdEntityIdEncoding (trunk relative position)
# ===========================================================================
class ResIdxAsymIdSymIdEntityIdEncoding(nn.Module):
    """Relative-position pair features projected to ``d_pair`` channels.

    ``embed.weight`` is ``[d_pair, n_features]`` with
    ``n_features = 2*(2*r_bins+2) + 1 + (2*c_bins+2)``: two one-hot blocks
    over residue/token offsets, one same-entity flag and one one-hot block
    over chain offsets.  For the defaults ``r_bins=32, c_bins=2`` this is
    ``2*66 + 1 + 6 = 139``.
    """

    def __init__(
        self,
        n_relative_residx_bins: int = 32,
        n_relative_chain_bins: int = 2,
        d_pair: int = 256,
    ) -> None:
        """Record the bin counts and create the bias-free feature projection."""
        super().__init__()
        self.n_relative_residx_bins = n_relative_residx_bins
        self.n_relative_chain_bins = n_relative_chain_bins
        self.d_pair = d_pair

        # Every one-hot block has 2*bins + 1 clipped offsets plus one
        # out-of-scope class.
        offset_classes = 2 * n_relative_residx_bins + 2
        chain_classes = 2 * n_relative_chain_bins + 2
        feature_width = offset_classes + offset_classes + chain_classes + 1
        self.embed = nn.Linear(feature_width, d_pair, bias=False)

    def forward(
        self,
        residue_index: Tensor,
        asym_id: Tensor,
        sym_id: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
    ) -> Tensor:
        """Encode all token pairs.

        Offsets are computed as ``value[i] - value[j]``, shifted by the bin
        count and clipped to ``[0, 2*bins]``; class ``2*bins + 1`` marks pairs
        for which the offset does not apply.

        Returns:
            ``self.embed`` applied to the concatenated features
            ``[rel_pos, rel_token, same_entity, rel_chain]``.
        """
        r_bins = self.n_relative_residx_bins
        c_bins = self.n_relative_chain_bins
        r_other = 2 * r_bins + 1
        c_other = 2 * c_bins + 1

        same_chain = asym_id.unsqueeze(2) == asym_id.unsqueeze(1)
        same_residue = residue_index.unsqueeze(2) == residue_index.unsqueeze(1)
        same_entity = entity_id.unsqueeze(2) == entity_id.unsqueeze(1)

        # Residue offset: only meaningful inside a chain.
        res_offset = residue_index.unsqueeze(2) - residue_index.unsqueeze(1)
        res_offset = torch.clip(res_offset + r_bins, 0, 2 * r_bins)
        res_offset = torch.where(same_chain, res_offset, r_other)
        rel_pos = F.one_hot(res_offset, 2 * r_bins + 2)

        # Token offset: only meaningful inside the same residue of a chain.
        tok_offset = token_index.unsqueeze(2) - token_index.unsqueeze(1) + r_bins
        tok_offset = torch.clip(tok_offset, 0, 2 * r_bins)
        tok_offset = torch.where(same_chain & same_residue, tok_offset, r_other)
        rel_token = F.one_hot(tok_offset, 2 * r_bins + 2)

        # Chain offset: only meaningful between different chains.
        chain_offset = sym_id.unsqueeze(2) - sym_id.unsqueeze(1) + c_bins
        chain_offset = torch.clip(chain_offset, 0, 2 * c_bins)
        chain_offset = torch.where(same_chain, c_other, chain_offset)
        rel_chain = F.one_hot(chain_offset, 2 * c_bins + 2)

        pair_features = torch.cat(
            (
                rel_pos.float(),
                rel_token.float(),
                same_entity.float().unsqueeze(-1),
                rel_chain.float(),
            ),
            dim=-1,
        )
        return self.embed(pair_features)
# ===========================================================================
# SingleToPair (for LanguageModelShim)
# ===========================================================================
class SingleToPair(nn.Module):
    """Lift a per-token representation to a pair representation.

    Tokens are first down-projected; for every ordered pair ``(i, j)`` the
    element-wise product and difference of the two projections are
    concatenated and passed through ``output_mlp`` (Linear, GELU, Linear).
    """

    def __init__(self, input_dim: int, downproject_dim: int, output_dim: int) -> None:
        """Create the down-projection and the two-layer output MLP."""
        super().__init__()
        self.downproject = nn.Linear(input_dim, downproject_dim)
        mlp_layers = (
            nn.Linear(2 * downproject_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )
        self.output_mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` indexed ``[b, l, input_dim]`` to ``[b, l, l, output_dim]``."""
        projected = self.downproject(x)
        as_rows = projected.unsqueeze(2)
        as_cols = projected.unsqueeze(1)
        pair_input = torch.cat((as_rows * as_cols, as_rows - as_cols), dim=3)
        return self.output_mlp(pair_input)
# ===========================================================================
# LanguageModelShim
# ===========================================================================
class LanguageModelShim(nn.Module):
    """Shim holding the trainable projection weights for LM integration.

    Contains:
    - base_z_combine: nn.Parameter [num_layers+1]
    - base_z_linear: Sequential(LayerNorm(d_model), Linear(d_model, d_z, bias=False))
    - base_z_mlp: Sequential(SingleToPair(d_z, d_z, d_z), LayerNorm(d_z))
    """

    def __init__(
        self, d_z: int = 256, d_model: int = 2560, num_layers: int = 80
    ) -> None:
        """Create the projection layers and the per-layer mixing logits.

        Args:
            d_z: Pair channel width produced by the shim.
            d_model: Hidden width of the language model.
            num_layers: Number of LM layers; ``num_layers + 1`` hidden states
                are mixed.
        """
        super().__init__()
        self.base_z_mlp = nn.Sequential(SingleToPair(d_z, d_z, d_z), nn.LayerNorm(d_z))
        self.base_z_linear = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_z, bias=False)
        )
        # Zero logits => uniform softmax weights over layers at initialisation.
        self.base_z_combine = nn.Parameter(torch.zeros(num_layers + 1))

    def forward(self, hidden_states: Tensor, *, lm_dropout: float = 0.0) -> Tensor:
        """Project pre-computed ESMC hidden states to pair representation.

        Args:
            hidden_states: [B, L, num_layers+1, d_model] from ESMC 6B.
            lm_dropout: Dropout probability applied to the pair
                representation after ``base_z_mlp``.  Dropout is applied
                whenever this is positive, regardless of ``self.training``.

        Returns:
            [B, L, L, d_pair] pair representation.
        """
        lm_z = self.base_z_linear(hidden_states)  # [B, L, N, d_z]
        weights = self.base_z_combine.softmax(0)  # [N]
        # matmul with a 1-D left operand contracts the layer axis and already
        # returns [B, L, d_z].  No squeeze is applied afterwards: squeezing
        # dim -2 would remove the sequence axis whenever L == 1.
        lm_z = weights @ lm_z  # [B, L, d_z]
        lm_z = self.base_z_mlp(lm_z)  # [B, L, L, d_z]
        if lm_dropout > 0:
            lm_z = F.dropout(lm_z, p=lm_dropout, training=True)
        return lm_z
# ===========================================================================
# Reproducibility helper (mirrors evolutionaryscale.utils.reproducibility)
# ===========================================================================
@contextmanager
def _seed_context(seed: int | None, *, cuda: bool = True):
"""Temporarily seed Python, NumPy, and PyTorch RNGs."""
if seed is None:
yield
return
py_state = random.getstate()
np_state = np.random.get_state()
torch_state = torch.get_rng_state()
cuda_states = (
torch.cuda.get_rng_state_all() if cuda and torch.cuda.is_available() else None
)
seed = int(seed) % (2**32)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda and torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
yield
finally:
random.setstate(py_state)
np.random.set_state(np_state)
torch.set_rng_state(torch_state)
if cuda_states is not None:
torch.cuda.set_rng_state_all(cuda_states)
# ===========================================================================
# compute_lm_hidden_states — ESMC hidden-state extraction
# ===========================================================================
def compute_lm_hidden_states(
esmc: nn.Module,
input_ids: Tensor,
asym_id: Tensor,
residue_index: Tensor,
mol_type: Tensor,
token_mask: Tensor,
pad_to_multiple: int | None = None,
lm_mask_pct: float = 0.0,
mask_token_id: int = 32,
) -> Tensor:
"""Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers.
Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple
structure tokens but share a single ``(asym_id, residue_index)`` key —
collapse them to one LM token per residue before running the LM (the LM
was trained on per-residue inputs, not per-atom), then scatter the
hidden states back to the per-token layout.
"""
B, L = input_ids.shape
device = input_ids.device
protein_mask = (mol_type == 0) & token_mask
lm_input_list = []
lm_lengths = []
# Per-batch maps from (original protein-token index) to (LM input position).
expand_maps: list[Tensor] = []
for b in range(B):
mask_b = protein_mask[b]
ids_b = input_ids[b][mask_b]
asym_b = asym_id[b][mask_b]
res_b = residue_index[b][mask_b]
# Collapse: keep first token per (asym_id, residue_index) key, in
# input order. ``inverse`` maps each original protein-token to its
# collapsed residue index.
keys = torch.stack((asym_b, res_b), dim=1)
unique_keys, inverse = torch.unique(keys, dim=0, return_inverse=True)
n_unique = unique_keys.size(0)
token_positions = torch.arange(keys.size(0), device=device, dtype=torch.long)
first_pos = torch.full(
(n_unique,), keys.size(0), device=device, dtype=torch.long
)
first_pos.scatter_reduce_(
0, inverse, token_positions, reduce="amin", include_self=True
)
ordered = torch.argsort(first_pos)
first_pos_ordered = first_pos[ordered]
ids_collapsed = ids_b[first_pos_ordered]
asym_collapsed = asym_b[first_pos_ordered]
remap = torch.empty_like(ordered)
remap[ordered] = torch.arange(n_unique, device=device, dtype=torch.long)
inverse_ordered = remap[inverse]
chain_ids = asym_collapsed.unique(sorted=True)
# [BOS] chain1 [EOS BOS] chain2 ... [EOS]
parts: list[Tensor] = [torch.tensor([0], device=device, dtype=ids_b.dtype)]
# Per-chain LM positions accumulate; track them for the expand map.
per_token_lm_pos = torch.empty(n_unique, device=device, dtype=torch.long)
cursor = 1 # position 0 is the leading BOS
for i, cid in enumerate(chain_ids):
in_chain = (asym_collapsed == cid).nonzero(as_tuple=True)[0]
parts.append(ids_collapsed[in_chain])
per_token_lm_pos[in_chain] = torch.arange(
cursor, cursor + in_chain.shape[0], device=device, dtype=torch.long
)
cursor += in_chain.shape[0]
if i < len(chain_ids) - 1:
parts.append(torch.tensor([2, 0], device=device, dtype=ids_b.dtype))
cursor += 2 # EOS + BOS
parts.append(torch.tensor([2], device=device, dtype=ids_b.dtype))
lm_seq = torch.cat(parts)
lm_input_list.append(lm_seq)
lm_lengths.append(lm_seq.shape[0])
# Original protein-token position → LM input position.
prot_pos_b = mask_b.nonzero(as_tuple=True)[0]
expand_map = torch.full((L,), -1, device=device, dtype=torch.long)
expand_map[prot_pos_b] = per_token_lm_pos[inverse_ordered]
expand_maps.append(expand_map)
# Pad to longest LM input; round to ``pad_to_multiple`` when fp8 is on
# (TE fp8 kernels assert prod(shape[:-1]) % 8 == 0).
max_len = max(lm_lengths)
if pad_to_multiple is not None and pad_to_multiple > 1:
max_len = ((max_len + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple
lm_input_ids = torch.full(
(B, max_len),
1,
device=device,
dtype=input_ids.dtype, # PAD=1
)
for b in range(B):
lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b]
# sequence_id for chain-aware attention; PAD tokens get -1 (no attention).
sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0
sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1
if lm_mask_pct > 0.0:
special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2)
do_mask = (
torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct
) & ~special
lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id)
with torch.inference_mode():
esmc_out = esmc(
input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True
)
hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D]
n_layers_plus_1, _, _, D = hs.shape
result = torch.zeros(B, L, n_layers_plus_1, D, device=device, dtype=hs.dtype)
for b in range(B):
mb = protein_mask[b]
em = expand_maps[b][mb] # [n_protein_tokens] LM positions
# hs[:, b, em, :] -> [n_layers+1, n_protein_tokens, D]
gathered = hs[:, b, em, :].permute(1, 0, 2)
result[b, mb.nonzero(as_tuple=True)[0]] = gathered
return result.detach()
# ===========================================================================
# TriangleMultiplicativeBlock
# ===========================================================================
class TriangleMultiplicativeBlock(nn.Module):
    """Triangle multiplicative update block with gated signal routing.

    The block layer-normalises the pair grid, projects it to a gated
    left/right stream pair, contracts the two streams with a flow-specific
    einsum, and projects the result back to ``input_channels`` under a
    sigmoid output gate. No residual is added here.

    Args:
        input_channels: Size of the last dimension of the pair grid.
        latent_channels: Width of each of the left/right streams.
        flow: ``"outgoing"`` or ``"incoming"``; selects the einsum equation.

    Raises:
        ValueError: If ``flow`` is not one of the supported flows.
    """

    _FLOW_TO_EINSUM = {"outgoing": "bikd,bjkd->bijd", "incoming": "bkid,bkjd->bijd"}
    _VALID_FLOWS = ("outgoing", "incoming")

    def __init__(self, input_channels: int, latent_channels: int, flow: str) -> None:
        super().__init__()
        if flow not in self._FLOW_TO_EINSUM:
            raise ValueError(
                f"Invalid flow={flow!r}. Expected one of {self._VALID_FLOWS}."
            )
        self.input_channels = input_channels
        self.latent_channels = latent_channels
        self.flow = flow
        self._einsum_equation = self._FLOW_TO_EINSUM[flow]
        self.norm_start = nn.LayerNorm(self.input_channels, eps=_EPS)
        self.norm_mix = nn.LayerNorm(self.latent_channels, eps=_EPS)
        # One projection yields four latent-width slabs: two signal streams
        # followed by their two gate streams (see ``forward``).
        self.proj_bundle = nn.Linear(
            self.input_channels, 4 * self.latent_channels, bias=False
        )
        self.proj_emit = nn.Linear(
            self.latent_channels, self.input_channels, bias=False
        )
        self.proj_gate = nn.Linear(self.input_channels, self.input_channels, bias=False)
        self._use_kernels: bool = False
        # Default chunked for memory on long sequences; tests override with
        # ``set_chunk_size(None)`` for the unchunked path under bit-exact bf16
        # parity checks.
        self._chunk_size: int | None = 64

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Set the einsum chunk length; ``None`` selects the unchunked path."""
        self._chunk_size = chunk_size

    def split_kernel_weights(self) -> tuple[Tensor, Tensor]:
        """Return ``proj_bundle.weight`` split into (signal rows, gate rows).

        The first ``2 * latent_channels`` output rows produce the signal
        streams and the remaining rows produce the gate logits, matching the
        ``split`` performed in ``forward``.
        """
        return (
            self.proj_bundle.weight[: 2 * self.latent_channels, :],
            self.proj_bundle.weight[2 * self.latent_channels :, :],
        )

    def _kernel_flow_direction(self) -> str:
        """Return the direction string handed to the external kernel."""
        return self.flow

    def _triangular_contract(self, left_stream: Tensor, right_stream: Tensor) -> Tensor:
        """Contract the two streams with the flow's einsum in a single call."""
        return torch.einsum(self._einsum_equation, left_stream, right_stream)

    def _triangular_contract_chunked(
        self, left_stream: Tensor, right_stream: Tensor, chunk_size: int
    ) -> Tensor:
        """Compute the triangular einsum in chunks along the output i-dimension.

        Args:
            left_stream: Left operand of the flow's einsum.
            right_stream: Right operand of the flow's einsum.
            chunk_size: Maximum number of output rows computed per einsum call.

        Returns:
            The same tensor ``_triangular_contract`` would produce.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size!r}."
            )
        # The output i-axis is dim 1 of the left operand for "outgoing" (bikd)
        # and dim 2 for "incoming" (bkid).
        outgoing = self.flow == "outgoing"
        length = left_stream.shape[1] if outgoing else left_stream.shape[2]
        if length <= chunk_size:
            # A single chunk covers everything, so skip the slice + cat. This
            # also covers an empty axis, where ``torch.cat`` would be handed an
            # empty list and raise.
            return self._triangular_contract(left_stream, right_stream)
        chunks = []
        for start in range(0, length, chunk_size):
            stop = min(start + chunk_size, length)
            if outgoing:
                window = left_stream[:, start:stop]
            else:
                window = left_stream[:, :, start:stop]
            chunks.append(
                torch.einsum(self._einsum_equation, window, right_stream)
            )
        # Every chunk is [b, i_chunk, j, d]; stitch along i.
        return torch.cat(chunks, dim=1)

    def forward(self, pair_grid: Tensor, visibility: Tensor | None = None) -> Tensor:
        """Apply the triangle multiplicative update.

        Args:
            pair_grid: Pair tensor with ``input_channels`` in its last
                dimension; the einsum equations index it as 4-D.
            visibility: Optional mask shaped like ``pair_grid.shape[:-1]``;
                defaults to all ones.

        Returns:
            The gated update ``proj_emit(norm_mix(contracted)) * output_gate``
            on the reference path, or the external kernel's output when the
            kernel backend is enabled and succeeds.
        """
        if visibility is None:
            visibility = pair_grid.new_ones(pair_grid.shape[:-1])
        if self._use_kernels:
            p_in_weight, g_in_weight = self.split_kernel_weights()
            try:
                return _cue_tri_mul(  # type: ignore[misc]
                    pair_grid,
                    direction=self._kernel_flow_direction(),
                    mask=visibility,
                    norm_in_weight=self.norm_start.weight,
                    norm_in_bias=self.norm_start.bias,
                    p_in_weight=p_in_weight,
                    g_in_weight=g_in_weight,
                    norm_out_weight=self.norm_mix.weight,
                    norm_out_bias=self.norm_mix.bias,
                    p_out_weight=self.proj_emit.weight,
                    g_out_weight=self.proj_gate.weight,
                    eps=_EPS,
                )
            except Exception as e:
                # Deliberate best-effort: any kernel failure is logged and the
                # reference einsum path below is used instead.
                import logging as _logging

                _logging.getLogger(__name__).warning(
                    "cuequivariance triangle_multiplicative_update kernel failed "
                    "(flow=%s, shape=%s, dtype=%s); falling back to chunked einsum. "
                    "Error: %s",
                    self.flow,
                    tuple(pair_grid.shape),
                    pair_grid.dtype,
                    e,
                )
        normalized_grid = self.norm_start(pair_grid)
        bundled = self.proj_bundle(normalized_grid)
        signal, gate_logits = bundled.split(2 * self.latent_channels, dim=-1)
        routed = signal * torch.sigmoid(gate_logits)
        # Zero out masked pair positions before they enter the contraction.
        routed = routed * visibility.unsqueeze(-1)
        # The contraction runs in float32.
        left_stream, right_stream = routed.float().chunk(2, dim=-1)
        if self._chunk_size is not None:
            contracted = self._triangular_contract_chunked(
                left_stream, right_stream, self._chunk_size
            )
        else:
            contracted = self._triangular_contract(left_stream, right_stream)
        mixed = self.proj_emit(self.norm_mix(contracted))
        output_gate = torch.sigmoid(self.proj_gate(normalized_grid))
        return mixed * output_gate
class TriangleMultiplicativeUpdate(nn.Module):
    """Thin wrapper exposing the triangular mixer with explicit orientation (v3).

    Args:
        dim: Channel width used for both the input and latent channels of the
            wrapped ``TriangleMultiplicativeBlock``.
        _outgoing: ``True`` builds the "outgoing" flow, ``False`` the
            "incoming" flow.
    """

    def __init__(self, dim: int = 128, _outgoing: bool = True) -> None:
        super().__init__()
        flow = "outgoing" if _outgoing else "incoming"
        self._engine = TriangleMultiplicativeBlock(
            input_channels=dim, latent_channels=dim, flow=flow
        )

    def set_kernel_backend(self, backend: str | None) -> None:
        """Enable the cuequivariance kernel on the engine for that backend only.

        Raises:
            RuntimeError: If the cuequivariance backend is requested but
                ``cuequivariance_torch`` is not installed. The engine's state
                is left untouched in that case.
        """
        # Validate before mutating, so a rejected request cannot leave the
        # engine pointing at a kernel that does not exist.
        if backend == BACKEND_CUEQ and not CUE_AVAILABLE:
            raise RuntimeError(
                "backend='cuequivariance' but cuequivariance_torch is not installed."
            )
        # Engine uses cueq when backend=="cuequivariance"; the "fused" backend
        # routes through the parent PairUpdateBlock's fused path (bypassing this).
        self._engine._use_kernels = backend == BACKEND_CUEQ

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Forward the chunk size to the wrapped engine."""
        self._engine.set_chunk_size(chunk_size)

    def forward(self, z: Tensor, mask: Tensor | None = None) -> Tensor:
        """Run the wrapped engine on ``z``, passing ``mask`` as its visibility."""
        return self._engine(z, visibility=mask)
# ===========================================================================
# FoldingTrunk: Transition, PairUpdateBlock, FoldingTrunk
# ===========================================================================
class Transition(nn.Module):
    """Pre-norm SwiGLU feed-forward block that returns ``x + ffn(norm(x))``.

    Two execution paths exist:

    * Reference path: plain ``x + ffn(norm(x))``, optionally chunked along
      dimension 1.
    * Fused path (grad disabled only): a fused LN + w12 + SwiGLU module feeds
      an ``addmm`` that folds the final ``w3`` projection and the residual add
      into one call.
    """

    def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)
        # Fused-kernel state: nothing installed until set_kernel_backend runs.
        self._kernel_backend: str | None = None
        self._fused_swiglu: nn.Module | None = None
        # Chunked by default; set_chunk_size(None) turns chunking off for
        # bit-exact parity tests.
        self._chunk_size: int | None = 64

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Set the chunk length along dimension 1; ``None`` disables chunking."""
        self._chunk_size = chunk_size

    def set_kernel_backend(self, backend: str | None) -> None:
        """Install / uninstall FusedLNLinearSwiGLU (no cueq equivalent).

        Raises:
            ValueError: If ``backend`` is not one of the valid backends.
        """
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        self._kernel_backend = backend

        if not (backend == BACKEND_FUSED and TRITON_KERNELS_AVAILABLE):
            # Any other backend (or missing Triton kernels) drops the module.
            self._fused_swiglu = None
            return

        assert _FusedLNLinearSwiGLU is not None
        w12_weight = self.ffn.w12.weight
        ln_has_bias = self.norm.bias is not None
        kernel = _FusedLNLinearSwiGLU(
            d_model=self.norm.normalized_shape[0],
            d_inner=self.ffn.hidden_features,
            has_ln_bias=ln_has_bias,
            device=w12_weight.device,
            dtype=w12_weight.dtype,
        )
        # Snapshot the current LayerNorm / w12 parameters into the kernel.
        with torch.no_grad():
            kernel.LN_W.copy_(self.norm.weight)
            if ln_has_bias:
                kernel.LN_B.copy_(self.norm.bias)  # type: ignore[union-attr]
            # FusedLNLinearSwiGLU.W12 is (d_model, 2*d_inner), i.e. the
            # transpose of the nn.Linear weight.
            kernel.W12.copy_(w12_weight.t().contiguous())
        self._fused_swiglu = kernel.eval().requires_grad_(False)

    def _can_use_fused_path(self, x: Tensor) -> bool:
        """Report whether the fused module may be used for ``x`` (bf16 only)."""
        return (
            _fused_active(self, x)
            and self._fused_swiglu is not None
            and x.dtype == torch.bfloat16
        )

    def _swiglu_pre_w3(self, x_normed: Tensor) -> Tensor:
        """Return ``silu(x1) * x2`` from the w12 projection, i.e. SwiGLU before w3."""
        gate_half, value_half = self.ffn.w12(x_normed).split(
            self.ffn.hidden_features, dim=-1
        )
        return F.silu(gate_half) * value_half

    def _addmm_residual(self, x: Tensor, hidden: Tensor) -> Tensor:
        """Compute ``x + w3(hidden)`` with one ``addmm`` over flattened rows."""
        flat_residual = x.contiguous().view(-1, x.shape[-1])
        flat_hidden = hidden.view(-1, hidden.shape[-1])
        fused_sum = torch.addmm(flat_residual, flat_hidden, self.ffn.w3.weight.t())
        return fused_sum.view(x.shape)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + ffn(norm(x))``, chunked along dimension 1 when configured."""
        chunk = self._chunk_size

        # Inference-only fast path. Its results are not bit-identical to the
        # reference expression, so it is used only while grad is disabled;
        # binder-design and bit-exact tests run with grad on and therefore
        # take the reference path.
        if not torch.is_grad_enabled() and self._can_use_fused_path(x):
            kernel = self._fused_swiglu
            assert kernel is not None
            if chunk is None or x.shape[1] <= chunk:
                return self._addmm_residual(x, kernel(x))
            total = x.shape[1]
            result = torch.empty_like(x)
            for start in range(0, total, chunk):
                window = slice(start, min(start + chunk, total))
                piece = x[:, window]
                result[:, window] = self._addmm_residual(piece, kernel(piece))
            return result

        def _reference(piece: Tensor) -> Tensor:
            # Reference path — bit-exact with main: x + ffn(norm(x)).
            return piece + self.ffn(self.norm(piece))

        if chunk is None or x.shape[1] <= chunk:
            return _reference(x)
        total = x.shape[1]
        pieces: list[Tensor] = [
            _reference(x[:, start : min(start + chunk, total)])
            for start in range(0, total, chunk)
        ]
        return torch.cat(pieces, dim=1)
class PairUpdateBlock(nn.Module):
    """One pair-stack layer: tri_mul_out, tri_mul_in, then pair_transition.

    Both triangle updates are combined with the running pair tensor through
    ``row_drop`` on the reference path, or through the fused
    TriMul+residual kernel when that path is usable.
    """

    def __init__(self, d_pair: int = 256, expansion_ratio: int = 4) -> None:
        super().__init__()
        self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
        self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
        self.pair_transition = Transition(d_pair, expansion_ratio=expansion_ratio)
        self._kernel_backend: str | None = None
        # Row-shared dropout-residual with rate 0 (the HF model is
        # inference-only). Selecting backend='fused' later swaps in the
        # FusedDropoutResidual Triton kernel.
        self.row_drop = DropoutResidual(0.0, batch_dim=1, use_fused_kernels=False)

    def set_kernel_backend(self, backend: str | None) -> None:
        """Propagate ``backend`` to the sub-modules and rebuild ``row_drop``.

        Raises:
            ValueError: If ``backend`` is not one of the valid backends.
        """
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        for child in (self.tri_mul_out, self.tri_mul_in, self.pair_transition):
            child.set_kernel_backend(backend)
        self._kernel_backend = backend
        wants_fused = backend == BACKEND_FUSED
        self.row_drop = DropoutResidual(
            0.0, batch_dim=1, use_fused_kernels=wants_fused
        )

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Apply ``chunk_size`` to both triangle updates and the transition."""
        for child in (self.tri_mul_out, self.tri_mul_in, self.pair_transition):
            child.set_chunk_size(chunk_size)

    def _can_use_fused_trimul_with_residual(self, pair: Tensor) -> bool:
        """Report whether the fused TriMul+residual kernel applies (bf16 only)."""
        return _fused_active(self, pair) and pair.dtype == torch.bfloat16

    def _fused_trimul_with_residual(
        self, pair: Tensor, direction: str, pair_attention_mask: Tensor | None
    ) -> Tensor:
        """Fused TriMul+residual call; weights from the corresponding engine."""
        wrapper = self.tri_mul_out if direction == "outgoing" else self.tri_mul_in
        engine: TriangleMultiplicativeBlock = wrapper._engine  # type: ignore[assignment]
        p_in_weight, g_in_weight = engine.split_kernel_weights()
        raw_weights = {
            "norm_in_weight": engine.norm_start.weight,
            "norm_in_bias": engine.norm_start.bias,
            "p_in_weight": p_in_weight,
            "g_in_weight": g_in_weight,
            "norm_out_weight": engine.norm_mix.weight,
            "norm_out_bias": engine.norm_mix.bias,
            "p_out_weight": engine.proj_emit.weight,
            "g_out_weight": engine.proj_gate.weight,
        }
        # Every weight is handed over as bf16; tensors already in bf16 are
        # passed through unchanged.
        bf16_weights = {
            key: (value if value.dtype == torch.bfloat16 else value.to(torch.bfloat16))
            for key, value in raw_weights.items()
        }
        # The method shares its name with the module-level kernel entry
        # point; the bare name below resolves to the module-level function.
        return _fused_trimul_with_residual(  # type: ignore[misc]
            pair,
            direction,
            residual=pair,
            drop_mask=None,  # inference: no dropout, matches internal's eval path
            mask=pair_attention_mask,
            eps=_EPS,
            **bf16_weights,
        )

    def forward(
        self, pair: Tensor, pair_attention_mask: Tensor | None = None
    ) -> Tensor:
        """Apply outgoing then incoming triangle updates, then the transition."""
        if self._can_use_fused_trimul_with_residual(pair):
            for direction in ("outgoing", "incoming"):
                pair = self._fused_trimul_with_residual(
                    pair, direction, pair_attention_mask
                )
        else:
            for tri_update in (self.tri_mul_out, self.tri_mul_in):
                pair = self.row_drop(
                    pair, tri_update(pair, mask=pair_attention_mask)
                )
        return self.pair_transition(pair)
class FoldingTrunk(nn.Module):
    """Sequential stack of ``PairUpdateBlock`` layers held in a ModuleList."""

    def __init__(
        self, n_layers: int = 24, d_pair: int = 256, expansion_ratio: int = 4
    ) -> None:
        super().__init__()
        layers = [
            PairUpdateBlock(d_pair=d_pair, expansion_ratio=expansion_ratio)
            for _ in range(n_layers)
        ]
        self.blocks = nn.ModuleList(layers)

    def set_kernel_backend(self, backend: str | None) -> None:
        """Forward ``backend`` to every block."""
        for layer in self.blocks:
            cast(PairUpdateBlock, layer).set_kernel_backend(backend)

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Forward ``chunk_size`` to every block."""
        for layer in self.blocks:
            cast(PairUpdateBlock, layer).set_chunk_size(chunk_size)

    def forward(
        self, pair: Tensor, pair_attention_mask: Tensor | None = None
    ) -> Tensor:
        """Run ``pair`` through every block and return it in its input dtype.

        With grad enabled each block runs under non-reentrant activation
        checkpointing; otherwise it is called directly.
        """
        input_dtype = pair.dtype
        # The first block's backend decides for the whole trunk. When the
        # fused trimul backend is on, CUDA inputs are cast to bf16 internally
        # (its bwd kernel requires bf16); other backends keep the input dtype.
        uses_fused = (
            len(self.blocks) > 0
            and getattr(self.blocks[0], "_kernel_backend", None) == BACKEND_FUSED
        )
        if pair.is_cuda and uses_fused and input_dtype != torch.bfloat16:
            pair = pair.to(torch.bfloat16)

        for layer in self.blocks:
            step = partial(layer, pair_attention_mask=pair_attention_mask)
            pair = (
                checkpoint(step, pair, use_reentrant=False)  # pyright: ignore
                if torch.is_grad_enabled()
                else step(pair)
            )

        # Undo the internal bf16 cast, if any.
        return pair if pair.dtype == input_dtype else pair.to(input_dtype)
# ===========================================================================
# MSA Encoder
# ===========================================================================
class OuterProductMean(nn.Module):
    """Outer-product mean: turns an MSA representation into a pair update.

    ``divide_outer_before_proj`` selects where the ``/ n_valid`` division sits
    relative to the ``Wout`` projection, because different ESMFold2
    checkpoints were trained with different orderings:

    * ``False`` (default): ``Wout(outer) / n_valid`` — the projection bias is
      divided by ``n_valid`` together with the outer product.
    * ``True``: ``Wout(outer / n_valid)`` — the projection bias is added after
      the division and is therefore unscaled.
    """

    def __init__(
        self,
        d_msa: int,
        d_hidden: int,
        d_pair: int,
        divide_outer_before_proj: bool = False,
    ) -> None:
        super().__init__()
        self.d_hidden = d_hidden
        self.divide_outer_before_proj = divide_outer_before_proj
        self.norm = nn.LayerNorm(d_msa)
        self.W = nn.Linear(d_msa, 2 * d_hidden, bias=False)
        self.Wout = nn.Linear(d_hidden * d_hidden, d_pair, bias=True)
        # Chunking is off by default for bit-exact bf16; call
        # ``set_chunk_size(64)`` for long sequences.
        self._chunk_size: int | None = None

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Set the chunk length along the left (i) axis; ``None`` disables it."""
        self._chunk_size = chunk_size

    def forward(self, m: Tensor, msa_attention_mask: Tensor) -> Tensor:
        """Compute the masked outer-product mean of ``m``.

        Args:
            m: MSA representation; the einsum indexes the projected tensor as
                ``[b, i, m, c]``.
            msa_attention_mask: Mask broadcast against ``m`` without its last
                dimension; masked entries contribute zero to the outer product.

        Returns:
            Pair update whose last dimension is ``d_pair``.
        """
        normed = self.norm(m)
        projected = self.W(normed) * msa_attention_mask.unsqueeze(-1).to(normed.dtype)
        left, right = projected.chunk(2, dim=-1)
        mask = msa_attention_mask.to(left.dtype)
        # Count of entries valid in both i and j, floored at 1 so fully
        # masked pairs do not divide by zero.
        counts = (mask @ mask.transpose(-1, -2)).unsqueeze(-1).clamp(min=1.0)

        def _pair_rows(left_rows: Tensor, count_rows: Tensor) -> Tensor:
            outer = torch.einsum("bimc,bjmd->bijcd", left_rows, right).flatten(-2)
            if self.divide_outer_before_proj:
                return self.Wout(outer / count_rows)
            return self.Wout(outer) / count_rows

        step = self._chunk_size
        if step is None:
            return _pair_rows(left, counts)

        # Chunk along the left (i) axis so the peak einsum intermediate is
        # [B, chunk, L, c, d] rather than [B, L, L, c, d].
        n_rows = left.shape[1]
        pieces: list[Tensor] = [
            _pair_rows(
                left[:, start : min(start + step, n_rows)],
                counts[:, start : min(start + step, n_rows)],
            )
            for start in range(0, n_rows, step)
        ]
        return torch.cat(pieces, dim=1)
class MSAPairWeightedAveraging(nn.Module):
    """Pair-biased MSA row update (AF3 Supplement Algorithm 10).

    Attention weights come only from the pair representation (one logit per
    head for every ``(i, j)``); they average per-head MSA values over ``j`` and
    the result is gated by a sigmoid of the MSA itself.
    """

    def __init__(
        self, d_msa: int, d_pair: int, n_heads: int = 8, head_width: int = 32
    ) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_width = head_width
        inner_width = n_heads * head_width
        self.norm_single = nn.LayerNorm(d_msa)
        self.compute_bias = nn.Sequential(
            nn.LayerNorm(d_pair), nn.Linear(d_pair, n_heads, bias=False)
        )
        self.Wv = nn.Linear(d_msa, inner_width, bias=False)
        self.Wgate = nn.Linear(d_msa, inner_width, bias=False)
        self.Wout = nn.Linear(inner_width, d_msa, bias=False)

    def forward(
        self, msa_repr: Tensor, pair_repr: Tensor, pair_attention_mask: Tensor
    ) -> Tensor:
        """
        Args:
            msa_repr:           [B, L, M, d_msa]
            pair_repr:          [B, L, L, d_pair]
            pair_attention_mask:[B, L, L]
        Returns:
            [B, L, M, d_msa]
        """
        batch, length, depth, _ = msa_repr.shape
        heads, width = self.n_heads, self.head_width
        per_head_shape = (batch, length, depth, heads, width)

        normed = self.norm_single(msa_repr)

        # Per-head logits [B, L, L, n_heads]; masked pairs get a large
        # negative value (written in place) so softmax all but ignores them.
        logits = self.compute_bias(pair_repr)
        hidden_pairs = ~pair_attention_mask.unsqueeze(-1).bool()
        logits.masked_fill_(hidden_pairs, -1e5)
        weights = torch.softmax(logits, dim=-2)  # normalised over j

        values = self.Wv(normed).reshape(per_head_shape)
        gates = torch.sigmoid(self.Wgate(normed)).reshape(per_head_shape)
        # Weighted average of values over j, gated element-wise at (i, m).
        mixed = torch.einsum("bijh,bjmhd,bimhd->bimhd", weights, values, gates)
        return self.Wout(mixed.reshape(batch, length, depth, heads * width))