# coding=utf-8 # Copyright 2026 Biohub. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 """Shared building blocks for ESMFold2 HuggingFace model variants.""" from __future__ import annotations import random import importlib from contextlib import contextmanager from functools import partial from typing import cast import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.utils.checkpoint import checkpoint try: flash_attn_module = importlib.import_module("flash_attn") flash_bert_padding = importlib.import_module("flash_attn.bert_padding") flash_attn_func = flash_attn_module.flash_attn_func flash_attn_varlen_func = flash_attn_module.flash_attn_varlen_func index_first_axis = flash_bert_padding.index_first_axis pad_input = flash_bert_padding.pad_input FLASH_ATTN_AVAILABLE = True except ImportError: flash_attn_func = None # type: ignore[assignment] flash_attn_varlen_func = None # type: ignore[assignment] index_first_axis = None # type: ignore[assignment] pad_input = None # type: ignore[assignment] FLASH_ATTN_AVAILABLE = False try: cue_module = importlib.import_module("cuequivariance_torch") cue_triangle = importlib.import_module("cuequivariance_torch.primitives.triangle") _cue_attn_pair_bias = cue_module.attention_pair_bias _cue_tri_mul = cue_triangle.triangle_multiplicative_update CUE_AVAILABLE = True except ImportError: _cue_attn_pair_bias = None # type: ignore[assignment] _cue_tri_mul = None # type: ignore[assignment] CUE_AVAILABLE = False # The Biohub release includes optional Triton kernels. FastPLMs keeps the # reference path enabled by default so Hugging Face remote-code loading can stay # flat and self-contained. _fused_pair_bias = None _fused_trimul_with_residual = None _FusedLNLinearSwiGLU = None _FusedDropoutResidual = None TRITON_KERNELS_AVAILABLE = False from .configuration_esmfold2 import ESMFold2Config BACKEND_FUSED = "fused" BACKEND_CUEQ = "cuequivariance" _VALID_BACKENDS = (None, BACKEND_FUSED, BACKEND_CUEQ) def _fused_active(module: nn.Module, tensor: Tensor) -> bool: """Common preconditions for the vendored fused Triton inference kernels.""" return ( TRITON_KERNELS_AVAILABLE and getattr(module, "_kernel_backend", None) == BACKEND_FUSED and not torch.is_grad_enabled() and tensor.is_cuda ) def _cueq_active(module: nn.Module) -> bool: return CUE_AVAILABLE and getattr(module, "_kernel_backend", None) == BACKEND_CUEQ class DropoutResidual(nn.Module): """``residual + dropout(delta)`` with row/col-shared dropout. Same signature on both paths. ``use_fused_kernels=True`` + ``batch_dim=1`` routes through ``FusedDropoutResidual`` (single-pass over pair tensor, in-place residual add). Falls back to unfused otherwise. """ def __init__( self, r: float, batch_dim: int, use_fused_kernels: bool = False ) -> None: super().__init__() assert batch_dim in (1, 2), f"batch_dim must be 1 or 2, got {batch_dim}" self._use_fused_kernels = ( use_fused_kernels and batch_dim == 1 and _FusedDropoutResidual is not None ) self._batch_dim = batch_dim self._r = r if self._use_fused_kernels: assert _FusedDropoutResidual is not None self._impl: nn.Module = _FusedDropoutResidual(r) else: self._impl = nn.Dropout(r) def forward(self, residual: Tensor, delta: Tensor) -> Tensor: if self._use_fused_kernels: return self._impl(residual, delta) # Unfused: row/col-shared dropout via [1, ...] mask broadcast. if self._r == 0.0 or not self.training: return residual + delta shape = list(delta.shape) shape[self._batch_dim] = 1 mask = self._impl(delta.new_ones(shape)) return residual + delta * mask # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- CHAR_VOCAB_SIZE: int = 64 MAX_CHARS: int = 4 XYZ_DIMS: int = 3 MAX_ATOMIC_NUMBER: int = 128 # Input feature dim = 3 + 1 + 1 + 128 + 64*4 = 389 ATOM_FEATURE_DIM: int = ( XYZ_DIMS + 1 + 1 + MAX_ATOMIC_NUMBER + CHAR_VOCAB_SIZE * MAX_CHARS ) NUM_RES_TYPES: int = 33 _EPS = 1e-5 # Default for the triangle / OPM / pair-transition L² ops. Caps peak memory # so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438; # chunk=64 leaves headroom for the largest foldbench targets). Override via # ``model.set_chunk_size(...)``; pass None to disable chunking (faster for # short L but OOM-prone past ~600). _DEFAULT_CHUNK_SIZE = 64 # =========================================================================== # MSA inference-time diversity augmentations # =========================================================================== def maybe_subsample_msa( msa: Tensor, msa_attention_mask: Tensor | None, has_deletion: Tensor | None, deletion_value: Tensor | None, *, max_depth: int | None, enabled: bool, ) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]: if not enabled or max_depth is None: return msa, msa_attention_mask, has_deletion, deletion_value depth = msa.size(1) if depth <= 1 or depth <= max_depth: return msa, msa_attention_mask, has_deletion, deletion_value indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device) indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1 indices = indices.sort().values msa = msa[:, indices] if msa_attention_mask is not None: msa_attention_mask = msa_attention_mask[:, indices] if has_deletion is not None: has_deletion = has_deletion[:, indices] if deletion_value is not None: deletion_value = deletion_value[:, indices] return msa, msa_attention_mask, has_deletion, deletion_value def maybe_apply_msa_column_masking( msa_attention_mask: Tensor | None, rate: float, ) -> Tensor | None: if msa_attention_mask is None or rate <= 0.0 or msa_attention_mask.size(1) <= 1: return msa_attention_mask batch_size, _, length = msa_attention_mask.shape col_keep = torch.rand(batch_size, length, device=msa_attention_mask.device) >= rate col_keep = col_keep.unsqueeze(1).expand_as(msa_attention_mask).clone() col_keep[:, 0, :] = True return msa_attention_mask.bool() & col_keep # =========================================================================== # Atom-token utilities # =========================================================================== def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor: """Broadcast per-token features to per-atom features using gather. Args: token_features: [B, L, d] atom_to_token_idx: [B, A] int64 Returns: [B, A, d] """ idx = atom_to_token_idx.unsqueeze(-1).expand(-1, -1, token_features.size(-1)) return torch.gather(token_features, 1, idx) def scatter_atom_to_token( atom_features: Tensor, atom_to_token_idx: Tensor, n_tokens: int, atom_mask: Tensor | None = None, ) -> Tensor: """Aggregate per-atom features to per-token features (mean). Args: atom_features: [B, A, d] atom_to_token_idx: [B, A] int64 n_tokens: L atom_mask: [B, A] bool Returns: [B, L, d] """ B, A, d = atom_features.shape n_out = n_tokens idx = atom_to_token_idx if atom_mask is not None: idx = torch.where(atom_mask, atom_to_token_idx, n_tokens) n_out = n_tokens + 1 idx_expanded = idx.unsqueeze(-1).expand(B, A, d) out = torch.zeros( B, n_out, d, device=atom_features.device, dtype=atom_features.dtype ) out.scatter_reduce_( 1, idx_expanded, atom_features, reduce="mean", include_self=False ) return out[:, :n_tokens, :] def gather_rep_atom_coords(coords: Tensor, rep_atom_idx: Tensor) -> Tensor: """Gather representative atom coordinates for each token. Args: coords: [B, A, 3] rep_atom_idx: [B, L] int64 Returns: [B, L, 3] """ idx = rep_atom_idx.unsqueeze(-1).expand(-1, -1, coords.size(-1)) return torch.gather(coords, 1, idx) def _compute_intra_token_idx(atom_to_token: Tensor) -> Tensor: """Compute local atom index within each token (vectorised). Atoms belonging to the same token are contiguous, so this computes a running count that resets at each token boundary. Args: atom_to_token: [B, A] flat index mapping each atom to its token. Returns: [B, A] tensor with values in [0, max_atoms_per_token - 1]. """ same_as_prev = F.pad( atom_to_token[:, 1:] == atom_to_token[:, :-1], (1, 0), value=False ) ones = torch.ones_like(atom_to_token) cumsum = torch.cumsum(ones, dim=-1) group_start = cumsum.masked_fill(same_as_prev, 0) group_start = torch.cummax(group_start, dim=-1).values return cumsum - group_start def _categorical_mean(logits: Tensor, start: float, end: float) -> Tensor: """Expected value of a categorical distribution over evenly-spaced bins. Equivalent to ``CategoricalMixture(logits, bins=logits.shape[-1], start, end).mean()``. Args: logits: [..., n_bins] start: left boundary end: right boundary Returns: [...] expected value """ n_bins = logits.shape[-1] edges = torch.linspace( start, end, n_bins + 1, device=logits.device, dtype=torch.float32 ) v_bins = (edges[:-1] + edges[1:]) / 2 # [n_bins] return (logits.float().softmax(-1) @ v_bins.unsqueeze(1)).squeeze(-1) # =========================================================================== # TransitionLayer (used in DiffusionConditioning) # =========================================================================== class TransitionLayer(nn.Module): """SwiGLU transition: norm -> a_proj, b_proj -> silu(a)*b -> out_proj.""" def __init__(self, d_model: int, n: int, eps: float = 1e-5) -> None: super().__init__() hidden = n * d_model self.norm = nn.LayerNorm(d_model, eps=eps) self.a_proj = nn.Linear(d_model, hidden, bias=False) self.b_proj = nn.Linear(d_model, hidden, bias=False) self.out_proj = nn.Linear(hidden, d_model, bias=False) def forward(self, x: Tensor) -> Tensor: x = self.norm(x) a = self.a_proj(x) b = self.b_proj(x) return self.out_proj(F.silu(a) * b) # =========================================================================== # AdaptiveLayerNorm (used in DiffusionTransformer) # =========================================================================== class AdaptiveLayerNorm(nn.Module): """Adaptive layer normalization (adaLN-Zero).""" def __init__(self, d_model: int, d_cond: int, eps: float = 1e-5) -> None: super().__init__() self.d_model = d_model self.d_cond = d_cond self.eps = eps self.s_scale = nn.Parameter(torch.ones(d_cond)) self.s_gate = nn.Linear(d_cond, d_model, bias=True) self.s_shift = nn.Linear(d_cond, d_model, bias=False) def forward(self, a: Tensor, s: Tensor) -> Tensor: a_norm = F.layer_norm(a, (self.d_model,), None, None, self.eps) s_norm = F.layer_norm(s, (self.d_cond,), self.s_scale, None, self.eps) return torch.sigmoid(self.s_gate(s_norm)) * a_norm + self.s_shift(s_norm) # =========================================================================== # FourierEmbedding # =========================================================================== class FourierEmbedding(nn.Module): """Fourier embedding: cos(2*pi*(t*w + b)).""" w: Tensor b: Tensor def __init__(self, c: int) -> None: super().__init__() self.c = c self.register_buffer("w", torch.randn(c)) self.register_buffer("b", torch.randn(c)) def forward(self, t_hat: Tensor) -> Tensor: t = torch.as_tensor(t_hat, device=self.w.device, dtype=self.w.dtype).reshape(-1) return torch.cos( 2.0 * torch.pi * (t[:, None] * self.w[None, :] + self.b[None, :]) ) # =========================================================================== # SwiGLU / SwiGLUMLP # =========================================================================== def _compute_swiglu_hidden_size(d_model: int, expansion_ratio: int) -> int: return expansion_ratio * d_model class SwiGLU(nn.Module): """SwiGLU with packed w12 and output w3.""" def __init__( self, in_features: int, hidden_features: int, out_features: int | None = None, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) self.hidden_features = hidden_features def forward(self, x: Tensor) -> Tensor: x12 = self.w12(x) x1, x2 = x12.split(self.hidden_features, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) class SwiGLUMLP(SwiGLU): """SwiGLU MLP with packed weights, no bias.""" def __init__( self, d_model: int, expansion_ratio: int = 4, bias: bool = False ) -> None: hidden = _compute_swiglu_hidden_size(d_model, expansion_ratio) super().__init__( in_features=d_model, hidden_features=hidden, out_features=d_model, bias=bias ) # =========================================================================== # SWA Atom Attention components # =========================================================================== def _rotate_half(x: Tensor) -> Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_emb_3d(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: """Apply RoPE with batch-dependent cos/sin. Args: x: [B, L, H, D] cos: [B, L, D/2] sin: [B, L, D/2] """ ro_dim = cos.shape[-1] * 2 cos = cos.unsqueeze(2).repeat(1, 1, 1, 2) sin = sin.unsqueeze(2).repeat(1, 1, 1, 2) return torch.cat( [x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]], dim=-1, ) @torch.compiler.disable def build_3d_rope( ref_pos: Tensor, ref_space_uid: Tensor, head_dim: int, n_spatial_per_axis: int = 4, n_uid_pairs: int = 2, spatial_base_freq: float = 10000.0, uid_base_freq: float = 10.0, ) -> tuple[Tensor, Tensor]: """Build cos/sin for 3D RoPE + UID RoPE.""" device = ref_pos.device B, N = ref_pos.shape[:2] half_dim = head_dim // 2 n_spatial_total = 3 * n_spatial_per_axis spatial_inv_freq = 1.0 / ( spatial_base_freq ** ( torch.arange(0, n_spatial_per_axis, dtype=torch.float32, device=device) / n_spatial_per_axis ) ) uid_inv_freq = 1.0 / ( uid_base_freq ** ( torch.arange(0, n_uid_pairs, dtype=torch.float32, device=device) / n_uid_pairs ) ) pos_f32 = ref_pos.float() spatial_freqs = torch.einsum("bna,k->bnak", pos_f32, spatial_inv_freq) spatial_freqs = spatial_freqs.reshape(B, N, n_spatial_total) uid_f32 = ref_space_uid.float() uid_freqs = torch.einsum("bn,k->bnk", uid_f32, uid_inv_freq) n_active = n_spatial_total + n_uid_pairs freqs = torch.cat([spatial_freqs, uid_freqs], dim=-1) if n_active < half_dim: padding = torch.zeros( B, N, half_dim - n_active, device=device, dtype=torch.float32 ) freqs = torch.cat([freqs, padding], dim=-1) cos = freqs.cos().to(torch.bfloat16) sin = freqs.sin().to(torch.bfloat16) return cos, sin def qk_norm(x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),)).to(x.dtype) # =========================================================================== # SwiGLUFFN (atom transformer blocks) # =========================================================================== class SwiGLUFFN(nn.Module): """SwiGLU FFN with rounded hidden size for hardware alignment.""" def __init__(self, d_model: int, expansion_ratio: int = 2) -> None: super().__init__() hidden_size = ((expansion_ratio * (d_model // 3) * 2) + 255) // 256 * 256 self.w_up = nn.Linear(d_model, 2 * hidden_size, bias=False) self.w_down = nn.Linear(hidden_size, d_model, bias=False) def forward(self, x: Tensor) -> Tensor: x = x.to(self.w_up.weight.dtype) x1, x2 = self.w_up(x).chunk(2, dim=-1) return self.w_down(F.silu(x1) * x2) # =========================================================================== # SWA3DRoPEAttention # =========================================================================== class SWA3DRoPEAttention(nn.Module): """Sliding window attention with 3D RoPE. Has Wqkv, gate_proj, out_proj.""" def __init__(self, d_model: int, n_heads: int, half_window: int = 64) -> None: super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads self.scale = self.head_dim**-0.5 self.half_window = half_window self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) self.gate_proj = nn.Linear(d_model, d_model, bias=False) def forward(self, x: Tensor, attention_params: tuple) -> Tensor: B, N = x.shape[:2] cos, sin = attention_params[0], attention_params[1] x_input = x qkv = self.Wqkv(x) qkv = qkv.view(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 1, 3, 4) q, k, v = qkv.unbind(0) q, k = qk_norm(q), qk_norm(k) q = apply_rotary_emb_3d(q, cos, sin) k = apply_rotary_emb_3d(k, cos, sin) input_dtype = q.dtype if q.dtype not in (torch.float16, torch.bfloat16): q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16() if len(attention_params) > 2 and FLASH_ATTN_AVAILABLE: indices, cu_seqlens, max_seqlen = ( attention_params[2], attention_params[3], attention_params[4], ) q_unpad = index_first_axis( # type: ignore[misc] q.reshape(-1, self.n_heads, self.head_dim), indices ) k_unpad = index_first_axis( # type: ignore[misc] k.reshape(-1, self.n_heads, self.head_dim), indices ) v_unpad = index_first_axis( # type: ignore[misc] v.reshape(-1, self.n_heads, self.head_dim), indices ) out_unpad = flash_attn_varlen_func( # type: ignore[misc] q_unpad, k_unpad, v_unpad, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, softmax_scale=self.scale, window_size=(self.half_window, self.half_window), ) out = pad_input(out_unpad, indices, B, N) # type: ignore[misc] elif FLASH_ATTN_AVAILABLE: out = flash_attn_func( # type: ignore[misc] q, k, v, softmax_scale=self.scale, window_size=(self.half_window, self.half_window), ) else: # Fallback: standard attention (no SWA) q_t = q.transpose(1, 2) k_t = k.transpose(1, 2) v_t = v.transpose(1, 2) attn = torch.matmul(q_t, k_t.transpose(-2, -1)) * self.scale attn = F.softmax(attn, dim=-1) out = torch.matmul(attn, v_t).transpose(1, 2) out = out.to(input_dtype).reshape(B, N, -1) # type: ignore[union-attr] out = out * torch.sigmoid(self.gate_proj(x_input)) return self.out_proj(out) # =========================================================================== # SWAAtomBlock, SWAAtomTransformer # =========================================================================== def _rms_adaln_raw(x: Tensor, scale: Tensor, shift: Tensor) -> Tensor: return F.rms_norm(x, (x.shape[-1],)) * (1 + scale) + shift def _gated_residual_raw(x: Tensor, gate: Tensor, y: Tensor) -> Tensor: return x + gate * y class SWAAtomBlock(nn.Module): """adaLN-Zero + SWA attention + SwiGLU FFN. Creates adaln_modulation = Sequential(SiLU(), Linear) -> keys like adaln_modulation.1.weight """ def __init__( self, d_atom: int, n_heads: int, half_window: int = 64, expansion_ratio: int = 2, use_compile_fusions: bool = False, ) -> None: super().__init__() self.attn_norm = nn.RMSNorm(d_atom, elementwise_affine=False) self.ffn_norm = nn.RMSNorm(d_atom, elementwise_affine=False) adaln_linear = nn.Linear(d_atom, 6 * d_atom, bias=False) nn.init.zeros_(adaln_linear.weight) self.adaln_modulation = nn.Sequential(nn.SiLU(), adaln_linear) self.attn = SWA3DRoPEAttention(d_atom, n_heads, half_window=half_window) self.ffn = SwiGLUFFN(d_atom, expansion_ratio) self._rms_adaln = ( torch.compile(_rms_adaln_raw) if use_compile_fusions else _rms_adaln_raw ) self._gated_residual = ( torch.compile(_gated_residual_raw) if use_compile_fusions else _gated_residual_raw ) def forward(self, x: Tensor, c_l: Tensor, attention_params: tuple) -> Tensor: mod = self.adaln_modulation(c_l) if mod.dim() == 2: mod = mod.unsqueeze(1) shift_a, scale_a, gate_a, shift_f, scale_f, gate_f = mod.chunk(6, dim=-1) attn_input = self._rms_adaln(x, scale_a, shift_a) attn_out = self.attn(attn_input, attention_params) x = self._gated_residual(x, gate_a, attn_out) ffn_input = self._rms_adaln(x, scale_f, shift_f) ffn_out = self.ffn(ffn_input) x = self._gated_residual(x, gate_f, ffn_out) return x class SWAAtomTransformer(nn.Module): """Stack of SWAAtomBlocks.""" def __init__( self, d_atom: int = 128, n_blocks: int = 3, n_heads: int = 4, swa_window_size: int = 128, expansion_ratio: int = 2, spatial_rope_base_frequency: float = 20.0, n_spatial_rope_pairs_per_axis: int = 2, n_uid_rope_pairs: int = 10, uid_rope_base_frequency: float = 10000.0, ) -> None: super().__init__() self.swa_window_size = swa_window_size self.head_dim = d_atom // n_heads self.spatial_rope_base_frequency = spatial_rope_base_frequency self.n_spatial_rope_pairs_per_axis = n_spatial_rope_pairs_per_axis self.n_uid_rope_pairs = n_uid_rope_pairs self.uid_rope_base_frequency = uid_rope_base_frequency self.blocks = nn.ModuleList( [ SWAAtomBlock( d_atom=d_atom, n_heads=n_heads, half_window=swa_window_size // 2, expansion_ratio=expansion_ratio, ) for _ in range(n_blocks) ] ) def _build_3d_rope( self, ref_pos: Tensor, ref_space_uid: Tensor ) -> tuple[Tensor, Tensor]: return build_3d_rope( ref_pos=ref_pos, ref_space_uid=ref_space_uid, head_dim=self.head_dim, n_spatial_per_axis=self.n_spatial_rope_pairs_per_axis, n_uid_pairs=self.n_uid_rope_pairs, spatial_base_freq=self.spatial_rope_base_frequency, uid_base_freq=self.uid_rope_base_frequency, ) def forward( self, q_l: Tensor, c_l: Tensor, attention_params: tuple, return_intermediates: bool = False, ) -> Tensor | tuple[Tensor, list[Tensor]]: intermediates: list[Tensor] = [] for block in self.blocks: q_l = block(q_l, c_l, attention_params) if return_intermediates: intermediates.append(q_l) if return_intermediates: return q_l, intermediates return q_l # =========================================================================== # ESMFold2AtomEncoder (for both inputs_embedder and diffusion_module) # =========================================================================== class ESMFold2AtomEncoder(nn.Module): """SWA atom encoder with atom_linear, atom_norm, atom_to_token_linear, [coords_linear], atom_transformer. Args: d_atom: atom hidden dim d_token: token dim for atom_to_token aggregation n_blocks, n_heads, swa_window_size, expansion_ratio: transformer params structure_prediction: if True, creates coords_linear and uses full d_token spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis, n_uid_rope_pairs, uid_rope_base_frequency: 3D RoPE config """ def __init__( self, d_atom: int = 128, d_token: int = 768, n_blocks: int = 3, n_heads: int = 4, swa_window_size: int = 128, expansion_ratio: int = 2, structure_prediction: bool = True, spatial_rope_base_frequency: float = 20.0, n_spatial_rope_pairs_per_axis: int = 2, n_uid_rope_pairs: int = 10, uid_rope_base_frequency: float = 10000.0, ) -> None: super().__init__() self.d_atom = d_atom self.d_token = d_token self.structure_prediction = structure_prediction self.atom_linear = nn.Linear(ATOM_FEATURE_DIM, d_atom, bias=False) self.atom_norm = nn.LayerNorm(d_atom) if structure_prediction: self.coords_linear = nn.Linear(6, d_atom, bias=False) self.atom_transformer = SWAAtomTransformer( d_atom=d_atom, n_blocks=n_blocks, n_heads=n_heads, swa_window_size=swa_window_size, expansion_ratio=expansion_ratio, spatial_rope_base_frequency=spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=n_uid_rope_pairs, uid_rope_base_frequency=uid_rope_base_frequency, ) # Output aggregation: d_token for structure prediction, d_token//2 for inputs out_dim = d_token if structure_prediction else d_token // 2 self.atom_to_token_linear = nn.Linear(d_atom, out_dim, bias=False) def forward( self, ref_pos: Tensor, atom_attention_mask: Tensor, ref_space_uid: Tensor, ref_charge: Tensor, ref_element: Tensor, ref_atom_name_chars: Tensor, atom_to_token: Tensor, r_l: Tensor | None = None, pred_r1: Tensor | None = None, s_i: Tensor | None = None, z_ij: Tensor | None = None, num_diffusion_samples: int = 1, return_intermediates: bool = False, inference_cache: dict | None = None, ) -> tuple[Tensor, Tensor, Tensor, tuple, list[Tensor]]: """Returns (a, q, c, attention_params, intermediates). ``inference_cache`` caches step-invariant tensors (c_base, 3D RoPE, attention indices, n_tokens) across diffusion steps. """ B, N = ref_pos.shape[:2] layer_cache = None if inference_cache is not None: layer_cache = inference_cache.setdefault("atomencoder", {}) if layer_cache is None or len(layer_cache) == 0: atom_feats = torch.cat( [ ref_pos, ref_charge.unsqueeze(-1), atom_attention_mask.unsqueeze(-1), ref_element, ref_atom_name_chars.reshape(B, N, MAX_CHARS * CHAR_VOCAB_SIZE), ], dim=-1, ) c_base = self.atom_norm(self.atom_linear(atom_feats)) cos, sin = self.atom_transformer._build_3d_rope(ref_pos, ref_space_uid) cos = cos.repeat_interleave(num_diffusion_samples, 0) sin = sin.repeat_interleave(num_diffusion_samples, 0) mask_exp = atom_attention_mask.repeat_interleave(num_diffusion_samples, 0) seqlens = mask_exp.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(mask_exp.flatten(), as_tuple=False).flatten() max_seqlen = int(seqlens.max().item()) cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) attention_params = (cos, sin, indices, cu_seqlens, max_seqlen) n_tokens = int(atom_to_token.max().item()) + 1 if layer_cache is not None: layer_cache["c_base"] = c_base layer_cache["attention_params"] = attention_params layer_cache["mask_exp"] = mask_exp layer_cache["n_tokens"] = n_tokens layer_cache["atom_to_token_exp"] = atom_to_token.repeat_interleave( num_diffusion_samples, 0 ) else: c_base = layer_cache["c_base"] attention_params = layer_cache["attention_params"] mask_exp = layer_cache["mask_exp"] n_tokens = layer_cache["n_tokens"] c = c_base q = c if self.structure_prediction and r_l is not None: q = q.repeat_interleave(num_diffusion_samples, 0) if pred_r1 is None: pred_r1 = torch.zeros_like(r_l) r_input = torch.cat([r_l, pred_r1], dim=-1) r_to_q = self.coords_linear(r_input) q = q + r_to_q c = c.repeat_interleave(num_diffusion_samples, 0) result = self.atom_transformer( q_l=q, c_l=c, attention_params=attention_params, return_intermediates=return_intermediates, ) if return_intermediates: q, intermediates = result else: q = result intermediates = [] q_to_a = F.relu(self.atom_to_token_linear(q)) if layer_cache is not None and "atom_to_token_exp" in layer_cache: atom_to_token_exp = layer_cache["atom_to_token_exp"] else: atom_to_token_exp = atom_to_token.repeat_interleave( num_diffusion_samples, 0 ) a = scatter_atom_to_token( q_to_a, atom_to_token_exp, n_tokens, atom_mask=mask_exp.bool() ) return a, q, c, attention_params, intermediates # =========================================================================== # ESMFold2AtomDecoder # =========================================================================== class ESMFold2AtomDecoder(nn.Module): """SWA atom decoder with token_to_atom_linear, atom_transformer, norm, output_linear.""" def __init__( self, d_atom: int = 128, d_token: int = 768, n_blocks: int = 3, n_heads: int = 4, swa_window_size: int = 128, expansion_ratio: int = 2, spatial_rope_base_frequency: float = 20.0, n_spatial_rope_pairs_per_axis: int = 2, n_uid_rope_pairs: int = 10, uid_rope_base_frequency: float = 10000.0, ) -> None: super().__init__() self.token_to_atom_linear = nn.Linear(d_token, d_atom, bias=False) self.atom_transformer = SWAAtomTransformer( d_atom=d_atom, n_blocks=n_blocks, n_heads=n_heads, swa_window_size=swa_window_size, expansion_ratio=expansion_ratio, spatial_rope_base_frequency=spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=n_uid_rope_pairs, uid_rope_base_frequency=uid_rope_base_frequency, ) self.norm = nn.LayerNorm(d_atom) self.output_linear = nn.Linear(d_atom, XYZ_DIMS, bias=False) def forward( self, a_i: Tensor, q_l: Tensor, c_l: Tensor, p_lm: tuple, atom_to_token: Tensor, atom_attention_mask: Tensor, num_diffusion_samples: int = 1, return_intermediates: bool = False, ) -> tuple[Tensor, list[Tensor]]: """Returns (r_update, intermediates).""" atom_to_token_exp = atom_to_token.repeat_interleave(num_diffusion_samples, 0) a_to_q = self.token_to_atom_linear(a_i) a_to_q = gather_token_to_atom(a_to_q, atom_to_token_exp) q_l = q_l + a_to_q result = self.atom_transformer( q_l=q_l, c_l=c_l, attention_params=p_lm, return_intermediates=return_intermediates, ) if return_intermediates: q_l, intermediates = result else: q_l = result intermediates = [] r_l = self.output_linear(self.norm(q_l)) return r_l, intermediates # =========================================================================== # AttentionPairBias (DiffusionTransformer attention block) # =========================================================================== class AttentionPairBias(nn.Module): """Gated multi-head attention with pair bias conditioning.""" def __init__( self, d_model: int, d_pair: int, num_heads: int, d_cond: int | None = None, use_conditioning: bool = True, ) -> None: super().__init__() self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads self.scale = self.head_dim**-0.5 d_cond = d_cond or d_model if use_conditioning: self.adaln = AdaptiveLayerNorm(d_model, d_cond, eps=1e-5) self.out_gate = nn.Linear(d_cond, d_model, bias=True) # adaln init: weight=0, bias=-2 nn.init.zeros_(self.out_gate.weight) nn.init.constant_(self.out_gate.bias, -2.0) else: self.pre_norm = nn.LayerNorm(d_model, eps=1e-5) self.q_proj = nn.Linear(d_model, d_model, bias=True) self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False) self.g_proj = nn.Linear(d_model, d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) if d_pair > 0: self.pair_norm = nn.LayerNorm(d_pair, eps=1e-5) self.pair_bias_proj = nn.Linear(d_pair, num_heads, bias=False) self._kernel_backend: str | None = None def set_kernel_backend(self, backend: str | None) -> None: if backend not in _VALID_BACKENDS: raise ValueError( f"backend must be one of {_VALID_BACKENDS}, got {backend!r}" ) self._kernel_backend = backend def _is_zero_beta(self, beta: Tensor | float) -> bool: if isinstance(beta, (int, float)): return beta == 0.0 return bool((beta == 0).all()) def _can_use_fused_pair_bias( self, z: Tensor, n_queries: int, beta: Tensor | float ) -> bool: return ( _fused_active(self, z) and z.dim() == 4 and self._is_zero_beta(beta) and hasattr(self, "pair_bias_proj") and hasattr(self, "pair_norm") ) def _can_use_cueq_pair_bias( self, z: Tensor, n_queries: int, beta: Tensor | float ) -> bool: return ( _cueq_active(self) and n_queries > 750 and z.dim() == 4 and self._is_zero_beta(beta) and hasattr(self, "pair_bias_proj") ) def forward( self, a: Tensor, s: Tensor | None, z: Tensor, beta: Tensor | float = 0.0, attention_mask: Tensor | None = None, num_diffusion_samples: int = 1, ) -> Tensor: bsz, n_queries, d_model = a.shape if s is not None: x = self.adaln(a, s) else: x = self.pre_norm(a) n_keys = x.shape[1] q = self.q_proj(x).view(bsz, n_queries, self.num_heads, self.head_dim) kv = self.kv_proj(x) k, v = kv.chunk(2, dim=-1) k = k.view(bsz, n_keys, self.num_heads, self.head_dim) v = v.view(bsz, n_keys, self.num_heads, self.head_dim) # Expand z for num_diffusion_samples if z.dim() == 4 and z.shape[0] != bsz and num_diffusion_samples > 1: z = z.repeat_interleave(num_diffusion_samples, dim=0) if ( attention_mask is not None and attention_mask.shape[0] != bsz and num_diffusion_samples > 1 ): attention_mask = attention_mask.repeat_interleave( num_diffusion_samples, dim=0 ) if self._can_use_fused_pair_bias(z, n_queries, beta): kernel_mask = ( attention_mask if attention_mask is not None else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool) ) pair_norm_w = self.pair_norm.weight pair_norm_b = ( self.pair_norm.bias if self.pair_norm.bias is not None else torch.zeros_like(pair_norm_w) ) z_bf = z if z.dtype == torch.bfloat16 else z.to(torch.bfloat16) bias = _fused_pair_bias( # type: ignore[misc] z_bf, kernel_mask, self.pair_bias_proj.weight, num_heads=self.num_heads, pair_norm_w=pair_norm_w, pair_norm_b=pair_norm_b, ) # (B, H, Q, K) q_bhqd = q.transpose(1, 2) k_bhqd = k.transpose(1, 2) v_bhqd = v.transpose(1, 2) attn_out = F.scaled_dot_product_attention( q_bhqd, k_bhqd, v_bhqd, attn_mask=bias.to(q_bhqd.dtype) ) g = torch.sigmoid(self.g_proj(x)).view( bsz, n_queries, self.num_heads, self.head_dim ) ctx = g * attn_out.transpose(1, 2) out = self.out_proj(ctx.reshape(bsz, n_queries, d_model)) if s is not None: out = torch.sigmoid(self.out_gate(s)) * out return out if self._can_use_cueq_pair_bias(z, n_queries, beta): kernel_mask = ( attention_mask if attention_mask is not None else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool) ) out, _ = _cue_attn_pair_bias( # type: ignore[misc] s=x, q=q.transpose(1, 2), k=k.transpose(1, 2), v=v.transpose(1, 2), z=z, mask=kernel_mask, num_heads=self.num_heads, w_proj_z=self.pair_bias_proj.weight, w_proj_g=self.g_proj.weight, w_proj_o=self.out_proj.weight, w_ln_z=self.pair_norm.weight, b_ln_z=self.pair_norm.bias, return_z_proj=False, is_cached_z_proj=False, ) else: # Standard attention with pair bias g = torch.sigmoid(self.g_proj(x)).view( bsz, n_queries, self.num_heads, self.head_dim ) logits = ( torch.einsum("... i h d, ... j h d -> ... i j h", q, k) * self.scale ) if z.dim() == 4: pair_bias = self.pair_bias_proj(self.pair_norm(z)) else: pair_bias = z.unsqueeze(-1) logits = logits + pair_bias.to(dtype=logits.dtype) if attention_mask is not None: min_val = torch.finfo(logits.dtype).min mask_bias = torch.where( attention_mask.bool()[:, None, :, None], 0.0, min_val ) logits = logits + mask_bias.to(dtype=logits.dtype) attn = torch.softmax(logits, dim=-2).to(dtype=v.dtype) ctx = torch.einsum("... i j h, ... j h d -> ... i h d", attn, v) ctx = g * ctx out = self.out_proj(ctx.reshape(bsz, n_queries, d_model)) if s is not None: out = torch.sigmoid(self.out_gate(s)) * out return out # =========================================================================== # ConditionedTransitionBlock # =========================================================================== class ConditionedTransitionBlock(nn.Module): """Conditioned SwiGLU transition with adaptive layer norm.""" def __init__( self, d_model: int, d_cond: int | None = None, transition_multiplier: int = 2, use_conditioning: bool = True, ) -> None: super().__init__() d_cond = d_cond or d_model hidden = transition_multiplier * d_model if use_conditioning: self.adaln = AdaptiveLayerNorm(d_model, d_cond, eps=1e-5) self.output_gate = nn.Linear(d_cond, d_model, bias=True) nn.init.zeros_(self.output_gate.weight) nn.init.constant_(self.output_gate.bias, -2.0) else: self.pre_norm = nn.LayerNorm(d_model, eps=1e-5) self.lin_swish = nn.Linear(d_model, 2 * hidden, bias=False) self.lin_out = nn.Linear(hidden, d_model, bias=False) def forward(self, a: Tensor, s: Tensor | None) -> Tensor: if s is not None: x = self.adaln(a, s) else: x = self.pre_norm(a) swish_a, swish_b = self.lin_swish(x).chunk(2, dim=-1) b = F.silu(swish_a) * swish_b out = self.lin_out(b) if s is not None: out = torch.sigmoid(self.output_gate(s)) * out return out # =========================================================================== # DiffusionTransformer (token transformer) # =========================================================================== class DiffusionTransformer(nn.Module): """Diffusion denoising transformer with attention pair bias.""" def __init__( self, d_model: int, d_pair: int, num_heads: int, num_blocks: int, d_cond: int | None = None, transition_multiplier: int = 2, use_conditioning: bool = True, ) -> None: super().__init__() d_cond = d_cond or d_model self.attn_blocks = nn.ModuleList( [ AttentionPairBias( d_model=d_model, d_pair=d_pair, num_heads=num_heads, d_cond=d_cond, use_conditioning=use_conditioning, ) for _ in range(num_blocks) ] ) self.transition_blocks = nn.ModuleList( [ ConditionedTransitionBlock( d_model=d_model, d_cond=d_cond, transition_multiplier=transition_multiplier, use_conditioning=use_conditioning, ) for _ in range(num_blocks) ] ) def set_kernel_backend(self, backend: str | None) -> None: for attn in self.attn_blocks: cast(AttentionPairBias, attn).set_kernel_backend(backend) def forward( self, a: Tensor, s: Tensor | None, z: Tensor, beta: Tensor | float = 0.0, attention_mask: Tensor | None = None, num_diffusion_samples: int = 1, return_intermediates: bool = False, ) -> tuple[Tensor, list[Tensor]]: intermediates: list[Tensor] = [] x = a for attn, transition in zip(self.attn_blocks, self.transition_blocks): x = x + attn( x, s, z, beta, attention_mask=attention_mask, num_diffusion_samples=num_diffusion_samples, ) x = x + transition(x, s) if return_intermediates: intermediates.append(x) return x, intermediates # =========================================================================== # DiffusionConditioning # =========================================================================== class DiffusionConditioning(nn.Module): """Conditions pair and single representations on noise timestep.""" def __init__( self, c_z: int = 256, c_s: int = 768, c_s_inputs: int = 451, sigma_data: float = 16.0, fourier_dim: int = 256, transition_multiplier: int = 2, layer_norm_eps: float = 1e-5, ) -> None: super().__init__() self.sigma_data = float(sigma_data) self.c_z = c_z self.c_s = c_s self.c_s_inputs = c_s_inputs self.z_input_norm = nn.LayerNorm(2 * c_z, eps=layer_norm_eps) self.z_proj = nn.Linear(2 * c_z, c_z, bias=False) self.z_transitions = nn.ModuleList( [ TransitionLayer(c_z, n=transition_multiplier, eps=layer_norm_eps) for _ in range(2) ] ) self.s_input_norm = nn.LayerNorm(c_s_inputs, eps=layer_norm_eps) self.s_proj = nn.Linear(c_s_inputs, c_s, bias=False) self.fourier = FourierEmbedding(fourier_dim) self.noise_norm = nn.LayerNorm(fourier_dim, eps=layer_norm_eps) self.noise_proj = nn.Linear(fourier_dim, c_s, bias=False) self.s_transitions = nn.ModuleList( [ TransitionLayer(c_s, n=transition_multiplier, eps=layer_norm_eps) for _ in range(2) ] ) def forward( self, t_hat: Tensor, s_inputs: Tensor, s_trunk: Tensor | None, z_trunk: Tensor, relative_position_encoding: Tensor, sigma_data: float | None = None, num_diffusion_samples: int = 1, inference_cache: dict[str, Tensor] | None = None, ) -> tuple[Tensor, Tensor]: sigma = self.sigma_data if sigma_data is None else float(sigma_data) base_batch = z_trunk.shape[0] target_batch = base_batch * num_diffusion_samples # z conditioning (cached across diffusion steps — independent of t_hat) if inference_cache is not None and "z" in inference_cache: z = inference_cache["z"] else: z_rel = relative_position_encoding.to(dtype=torch.float32) z = torch.cat([z_trunk.to(dtype=torch.float32), z_rel], dim=-1) z = self.z_proj(self.z_input_norm(z)) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): for block in self.z_transitions: z = z + block(z) if inference_cache is not None: inference_cache["z"] = z # s conditioning s_inputs_eff = s_inputs if s_inputs_eff.shape[0] != target_batch: s_inputs_eff = s_inputs_eff.repeat_interleave(num_diffusion_samples, 0) s = self.s_proj(self.s_input_norm(s_inputs_eff.to(dtype=torch.float32))) # Noise embedding t = torch.as_tensor(t_hat, dtype=torch.float32, device=s.device).reshape(-1) if t.numel() == 1: t = t.expand(target_batch) elif t.shape[0] != target_batch: t = t.repeat_interleave(num_diffusion_samples, 0) t_noise = 0.25 * torch.log((t / sigma).clamp(min=1e-20)) n = self.fourier(t_noise) n = self.noise_proj(self.noise_norm(n)) s = s + n.unsqueeze(1) for block in self.s_transitions: s = s + block(s) return s, z # =========================================================================== # DiffusionModule # =========================================================================== class DiffusionModule(nn.Module): """Diffusion denoising module for structure prediction.""" def __init__( self, c_atom: int = 128, c_token: int = 768, c_z: int = 256, c_s_inputs: int = 451, sigma_data: float = 16.0, fourier_dim: int = 256, atom_num_blocks: int = 3, atom_num_heads: int = 4, token_num_blocks: int = 12, token_num_heads: int = 16, transition_multiplier: int = 2, swa_window_size: int = 128, spatial_rope_base_frequency: float = 20.0, n_spatial_rope_pairs_per_axis: int = 2, n_uid_rope_pairs: int = 10, uid_rope_base_frequency: float = 10000.0, ) -> None: super().__init__() self.sigma_data = float(sigma_data) self.conditioning = DiffusionConditioning( c_z=c_z, c_s=c_token, # conditioning s output is c_token c_s_inputs=c_s_inputs, sigma_data=sigma_data, fourier_dim=fourier_dim, transition_multiplier=transition_multiplier, ) # Atom encoder (structure_prediction=True, with coords_linear) self.atom_encoder = ESMFold2AtomEncoder( d_atom=c_atom, d_token=c_token, n_blocks=atom_num_blocks, n_heads=atom_num_heads, swa_window_size=swa_window_size, expansion_ratio=2, structure_prediction=True, spatial_rope_base_frequency=spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=n_uid_rope_pairs, uid_rope_base_frequency=uid_rope_base_frequency, ) # Atom decoder self.atom_decoder = ESMFold2AtomDecoder( d_atom=c_atom, d_token=c_token, n_blocks=atom_num_blocks, n_heads=atom_num_heads, swa_window_size=swa_window_size, expansion_ratio=2, spatial_rope_base_frequency=spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=n_uid_rope_pairs, uid_rope_base_frequency=uid_rope_base_frequency, ) self.s_to_token = nn.Linear(c_token, c_token, bias=False) nn.init.zeros_(self.s_to_token.weight) # Token transformer (DiffusionTransformer with pair bias) self.token_transformer = DiffusionTransformer( d_model=c_token, d_pair=c_z, num_heads=token_num_heads, num_blocks=token_num_blocks, d_cond=c_token, transition_multiplier=transition_multiplier, use_conditioning=True, ) self.s_step_norm = nn.LayerNorm(c_token) self.token_norm = nn.LayerNorm(c_token) def set_kernel_backend(self, backend: str | None) -> None: self.token_transformer.set_kernel_backend(backend) def forward( self, x_noisy: Tensor, t_hat: Tensor, ref_pos: Tensor, ref_charge: Tensor, ref_mask: Tensor, ref_element: Tensor, ref_atom_name_chars: Tensor, ref_space_uid: Tensor, tok_idx: Tensor, s_inputs: Tensor, s_trunk: Tensor | None, z_trunk: Tensor, relative_position_encoding: Tensor, asym_id: Tensor, residue_index: Tensor, entity_id: Tensor, token_index: Tensor, sym_id: Tensor, sigma_data: float | None = None, token_attention_mask: Tensor | None = None, num_diffusion_samples: int = 1, return_token_repr: bool = False, return_atom_repr: bool = False, inference_cache: dict[str, Tensor] | None = None, ) -> dict[str, Tensor | None]: bsz = x_noisy.shape[0] sigma = self.sigma_data if sigma_data is None else float(sigma_data) t = torch.as_tensor(t_hat, dtype=torch.float32, device=x_noisy.device).reshape( -1 ) if t.numel() == 1: t = t.expand(bsz) # Step 1: conditioning (pair z is cached across diffusion steps) s, z = self.conditioning( t_hat=t, s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk, relative_position_encoding=relative_position_encoding, sigma_data=sigma, num_diffusion_samples=num_diffusion_samples, inference_cache=inference_cache, ) # Step 2: normalize noisy coords denom = torch.sqrt(t * t + sigma * sigma) r_noisy = x_noisy / denom[:, None, None] # Step 3: atom encoder a, q_skip, c_skip, p_skip, enc_intermediates = self.atom_encoder( ref_pos=ref_pos, atom_attention_mask=ref_mask, ref_space_uid=ref_space_uid, ref_charge=ref_charge, ref_element=ref_element, ref_atom_name_chars=ref_atom_name_chars, atom_to_token=tok_idx, r_l=r_noisy, s_i=s_trunk, num_diffusion_samples=num_diffusion_samples, return_intermediates=return_atom_repr, inference_cache=inference_cache, ) # Step 4: add conditioned s a = a + self.s_to_token(self.s_step_norm(s)) # Step 5: token transformer a, _ = self.token_transformer( a, s, z, beta=0.0, attention_mask=token_attention_mask, num_diffusion_samples=num_diffusion_samples, ) # Step 6: token norm a = self.token_norm(a) # Step 7: atom decoder r_update, dec_intermediates = self.atom_decoder( a_i=a, q_l=q_skip, c_l=c_skip, p_lm=p_skip, atom_to_token=tok_idx, atom_attention_mask=ref_mask, num_diffusion_samples=num_diffusion_samples, return_intermediates=return_atom_repr, ) # Step 8: compute denoised output sigma2 = sigma * sigma t2 = t * t out = (sigma2 / (sigma2 + t2))[:, None, None] * x_noisy out = out + ((sigma * t) / torch.sqrt(sigma2 + t2))[:, None, None] * r_update # Collect atom intermediates from encoder + decoder atom_intermediates: Tensor | None = None if return_atom_repr: all_ints = enc_intermediates + dec_intermediates if all_ints: atom_intermediates = torch.stack(all_ints, dim=2) return { "x_denoised": out, "token_repr": a if return_token_repr else None, "atom_intermediates": atom_intermediates, } # =========================================================================== # DiffusionStructureHead # =========================================================================== class DiffusionStructureHead(nn.Module): """Wrapper around DiffusionModule with diffusion sampling.""" def __init__(self, config: ESMFold2Config) -> None: super().__init__() dm = config.structure_head.diffusion_module swa_cfg = config.inputs.atom_encoder sh = config.structure_head self.diffusion_module = DiffusionModule( c_atom=dm.c_atom, c_token=dm.c_token, c_z=dm.c_z, c_s_inputs=dm.c_s_inputs, sigma_data=dm.sigma_data, fourier_dim=dm.fourier_dim, atom_num_blocks=dm.atom_num_blocks, atom_num_heads=dm.atom_num_heads, token_num_blocks=dm.token_num_blocks, token_num_heads=dm.token_num_heads, transition_multiplier=dm.transition_multiplier, swa_window_size=swa_cfg.swa_window_size, spatial_rope_base_frequency=swa_cfg.spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=swa_cfg.n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=swa_cfg.n_uid_rope_pairs, uid_rope_base_frequency=swa_cfg.uid_rope_base_frequency, ) # Sampling hyperparameters self.sigma_data = dm.sigma_data self.gamma_0 = sh.gamma_0 self.gamma_min = sh.gamma_min self.noise_scale = sh.noise_scale self.step_scale = sh.step_scale self.inference_s_max = sh.inference_s_max self.inference_s_min = sh.inference_s_min self.inference_p = sh.inference_p self.inference_num_steps = sh.inference_num_steps def set_kernel_backend(self, backend: str | None) -> None: self.diffusion_module.set_kernel_backend(backend) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def inference_noise_schedule( self, num_steps: int | None = None, device: torch.device | None = None ) -> Tensor: """Karras power-law noise schedule.""" steps = self.inference_num_steps if num_steps is None else int(num_steps) if steps == 1: return torch.tensor( [self.inference_s_max * self.sigma_data, 0.0], device=device, dtype=torch.float32, ) p = float(self.inference_p) inv_p = 1.0 / p k = torch.arange(steps, device=device, dtype=torch.float32) base = self.inference_s_max**inv_p + (k / (steps - 1)) * ( self.inference_s_min**inv_p - self.inference_s_max**inv_p ) schedule = self.sigma_data * base.pow(p) return F.pad(schedule, (0, 1), value=0.0) @staticmethod def _random_rotations(n: int, dtype: torch.dtype, device: torch.device) -> Tensor: q = torch.randn((n, 4), dtype=dtype, device=device) scale = torch.sqrt((q * q).sum(dim=1)) signs = torch.where(q[:, 0] < 0, -scale, scale) q = q / signs[:, None] r, i, j, k = torch.unbind(q, dim=-1) two_s = 2.0 / (q * q).sum(dim=-1) return torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), dim=-1, ).reshape(n, 3, 3) def _center_random_augmentation( self, x: Tensor, atom_mask: Tensor, second_coords: Tensor | None = None ) -> tuple[Tensor, Tensor | None]: """Algorithm 19: center + random rotation + translation.""" bsz = x.shape[0] mask = atom_mask.unsqueeze(-1) # [B, A, 1] denom = mask.sum(dim=1, keepdim=True).clamp(min=1) mean = (x * mask).sum(dim=1, keepdim=True) / denom x = x - mean if second_coords is not None: second_coords = second_coords - mean r = self._random_rotations(bsz, x.dtype, x.device) x = torch.einsum("bmd,bds->bms", x, r) if second_coords is not None: second_coords = torch.einsum("bmd,bds->bms", second_coords, r) t = torch.randn_like(x[:, 0:1, :]) x = x + t if second_coords is not None: second_coords = second_coords + t return x, second_coords @staticmethod def _weighted_rigid_align( x: Tensor, x_gt: Tensor, w: Tensor, mask: Tensor ) -> Tensor: """Kabsch alignment: align x to x_gt with weights w.""" w = (mask * w).unsqueeze(-1) # [B, N, 1] denom = w.sum(dim=-2, keepdim=True).clamp(min=1e-8) mu = (x * w).sum(dim=-2, keepdim=True) / denom mu_gt = (x_gt * w).sum(dim=-2, keepdim=True) / denom x_c = x - mu xgt_c = x_gt - mu_gt H = torch.einsum("bni,bnj->bij", w * xgt_c, x_c) H32 = H.float() U, _, Vh = torch.linalg.svd(H32, driver="gesvd" if H32.is_cuda else None) det = torch.linalg.det(U @ Vh) ones = torch.ones_like(det) R = (U @ torch.diag_embed(torch.stack([ones, ones, det], dim=-1)) @ Vh).to( H.dtype ) return x_c @ R.transpose(-1, -2) + mu_gt # ------------------------------------------------------------------ # Sampling # ------------------------------------------------------------------ @torch.inference_mode() def sample( self, z_trunk: Tensor, s_inputs: Tensor, s_trunk: Tensor | None, relative_position_encoding: Tensor, ref_pos: Tensor, ref_charge: Tensor, ref_mask: Tensor, ref_element: Tensor, ref_atom_name_chars: Tensor, ref_space_uid: Tensor, tok_idx: Tensor, asym_id: Tensor, residue_index: Tensor, entity_id: Tensor, token_index: Tensor, sym_id: Tensor, token_attention_mask: Tensor | None = None, num_diffusion_samples: int = 1, num_sampling_steps: int | None = None, max_inference_sigma: float | None = 256.0, noise_scale: float | None = None, step_scale: float | None = None, return_atom_repr: bool = False, use_inference_cache: bool = True, denoising_early_exit_rmsd: float | None = None, ) -> dict[str, Tensor | None]: """Diffusion sampling (Algorithm 18). ``num_sampling_steps`` is the number of denoising steps actually run. When ``max_inference_sigma`` is set, the Karras schedule built with ``num_sampling_steps`` entries would lose its high-σ tail to the cap, so we inflate the underlying schedule length here to land back at the requested step count post-truncation. """ n_atoms = tok_idx.shape[1] device = s_inputs.device target_batch = s_inputs.shape[0] * num_diffusion_samples inference_cache: dict[str, Tensor] | None = {} if use_inference_cache else None steps = ( self.inference_num_steps if num_sampling_steps is None else int(num_sampling_steps) ) schedule = self.inference_noise_schedule(steps, device) if max_inference_sigma is not None: schedule = schedule[schedule <= float(max_inference_sigma)] schedule = F.pad(schedule, (1, 0), value=float(max_inference_sigma)) lam = self.noise_scale if noise_scale is None else float(noise_scale) eta = self.step_scale if step_scale is None else float(step_scale) x = schedule[0] * torch.randn( target_batch, n_atoms, 3, device=device, dtype=torch.float32 ) atom_mask = ref_mask.repeat_interleave(num_diffusion_samples, 0).float() gammas = torch.where( schedule > self.gamma_min, torch.full_like(schedule, self.gamma_0), torch.zeros_like(schedule), ) x_denoised_prev: Tensor | None = None token_repr: Tensor | None = None diff_atom_intermediates: Tensor | None = None step_pairs = list(zip(schedule[:-1], schedule[1:], gammas[1:])) num_steps = len(step_pairs) for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(step_pairs): x, x_denoised_prev = self._center_random_augmentation( x, atom_mask, second_coords=x_denoised_prev ) sigma_tm_val = float(sigma_tm.item()) t_hat_val = sigma_tm_val * (1.0 + float(gamma.item())) eps_std = lam * max(t_hat_val**2 - sigma_tm_val**2, 0.0) ** 0.5 x_noisy = x + eps_std * torch.randn_like(x) is_last_step = step_idx == num_steps - 1 request_atom_repr = return_atom_repr and ( is_last_step or denoising_early_exit_rmsd is not None ) dm_out = self.diffusion_module( x_noisy=x_noisy, t_hat=torch.full( (target_batch,), t_hat_val, device=device, dtype=torch.float32 ), ref_pos=ref_pos, ref_charge=ref_charge, ref_mask=ref_mask, ref_element=ref_element, ref_atom_name_chars=ref_atom_name_chars, ref_space_uid=ref_space_uid, tok_idx=tok_idx, s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk, relative_position_encoding=relative_position_encoding, asym_id=asym_id, residue_index=residue_index, entity_id=entity_id, token_index=token_index, sym_id=sym_id, token_attention_mask=token_attention_mask, num_diffusion_samples=num_diffusion_samples, return_token_repr=True, return_atom_repr=request_atom_repr, inference_cache=inference_cache, ) x_denoised = dm_out["x_denoised"] token_repr = dm_out["token_repr"] if request_atom_repr: diff_atom_intermediates = dm_out.get("atom_intermediates") # Reverse diffusion alignment (Kabsch) with torch.autocast(device_type="cuda", enabled=False): x_noisy = self._weighted_rigid_align( x_noisy.float(), x_denoised.float(), atom_mask, atom_mask ) x_noisy = x_noisy.to(dtype=x_denoised.dtype) # ODE/SDE step sigma_t_val = float(sigma_t.item()) denoised_over_sigma = (x_noisy - x_denoised) / t_hat_val x = x_noisy + eta * (sigma_t_val - t_hat_val) * denoised_over_sigma # Denoising early-exit: stop when consecutive predictions converge if ( denoising_early_exit_rmsd is not None and x_denoised_prev is not None and step_idx >= 1 ): with torch.autocast(device_type="cuda", enabled=False): aligned = self._weighted_rigid_align( x_denoised_prev.float(), x_denoised.float(), atom_mask, atom_mask, ) diff = (x_denoised.float() - aligned) * atom_mask.unsqueeze(-1) per_sample_rmsd = ( diff.pow(2).sum(dim=(-1, -2)) / atom_mask.sum(dim=-1).clamp(min=1) ).sqrt() if per_sample_rmsd.max().item() < denoising_early_exit_rmsd: x = x_denoised x_denoised_prev = x_denoised break x_denoised_prev = x_denoised result: dict[str, Tensor | None] = { "sample_atom_coords": x, "diff_token_repr": token_repr, } if return_atom_repr: result["diff_atom_intermediates"] = diff_atom_intermediates return result # =========================================================================== # RowAttentionPooling # =========================================================================== class RowAttentionPooling(nn.Module): """Row-wise attention pooling: attn_proj, out_proj.""" def __init__(self, d_pair: int, d_single: int) -> None: super().__init__() self.attn_proj = nn.Linear(d_pair, 1, bias=False) self.out_proj = nn.Linear(d_pair, d_single, bias=False) def forward(self, z: Tensor, mask: Tensor) -> Tensor: scores = self.attn_proj(z).squeeze(-1) mask_bias = torch.where( mask[:, None, :].bool(), torch.zeros_like(scores), torch.full_like(scores, -1e9), ) scores = scores + mask_bias weights = F.softmax(scores, dim=-1) pooled = torch.einsum("bnm,bnmd->bnd", weights, z) return self.out_proj(pooled) # =========================================================================== # InputsEmbedder # =========================================================================== class InputsEmbedder(nn.Module): """Embeds input features including atom-level encoding via SWA attention.""" def __init__(self, config: ESMFold2Config) -> None: super().__init__() swa_cfg = config.inputs.atom_encoder self.atom_attention_encoder = ESMFold2AtomEncoder( d_atom=swa_cfg.d_atom, d_token=swa_cfg.d_token, n_blocks=swa_cfg.n_blocks, n_heads=swa_cfg.n_heads, swa_window_size=swa_cfg.swa_window_size, expansion_ratio=swa_cfg.expansion_ratio, structure_prediction=False, # no coords_linear spatial_rope_base_frequency=swa_cfg.spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis=swa_cfg.n_spatial_rope_pairs_per_axis, n_uid_rope_pairs=swa_cfg.n_uid_rope_pairs, uid_rope_base_frequency=swa_cfg.uid_rope_base_frequency, ) def forward( self, aatype: Tensor, profile: Tensor, deletion_mean: Tensor, ref_pos: Tensor, atom_attention_mask: Tensor, ref_space_uid: Tensor, ref_charge: Tensor, ref_element: Tensor, ref_atom_name_chars: Tensor, atom_to_token: Tensor, ) -> Tensor: """Embed inputs into per-token features. Returns: [B, L, d_inputs] concatenation of atom encoding, aatype, profile, and deletion_mean. """ a, _q, _c, _attn_params, _intermediates = self.atom_attention_encoder( ref_pos=ref_pos, atom_attention_mask=atom_attention_mask, ref_space_uid=ref_space_uid, ref_charge=ref_charge, ref_element=ref_element, ref_atom_name_chars=ref_atom_name_chars, atom_to_token=atom_to_token, ) return torch.cat([a, aatype, profile, deletion_mean.unsqueeze(-1)], dim=-1) # =========================================================================== # ResIdxAsymIdSymIdEntityIdEncoding (trunk relative position) # =========================================================================== class ResIdxAsymIdSymIdEntityIdEncoding(nn.Module): """embed.weight [d_pair, n_features] where n_features = 2*(2*r_bins+2) + 1 + (2*c_bins+2). For default r_bins=32, c_bins=2: 2*66 + 1 + 6 = 139. """ def __init__( self, n_relative_residx_bins: int = 32, n_relative_chain_bins: int = 2, d_pair: int = 256, ) -> None: super().__init__() self.n_relative_residx_bins = n_relative_residx_bins self.n_relative_chain_bins = n_relative_chain_bins self.d_pair = d_pair n_feats_residue = 2 * n_relative_residx_bins + 2 n_feats_token = 2 * n_relative_residx_bins + 2 n_feats_chain = 2 * n_relative_chain_bins + 2 n_feats_same_entity = 1 total_feats = ( n_feats_residue + n_feats_token + n_feats_chain + n_feats_same_entity ) self.embed = nn.Linear(total_feats, d_pair, bias=False) def forward( self, residue_index: Tensor, asym_id: Tensor, sym_id: Tensor, entity_id: Tensor, token_index: Tensor, ) -> Tensor: bij_same_chain = asym_id.unsqueeze(2) == asym_id.unsqueeze(1) bij_same_residue = residue_index.unsqueeze(2) == residue_index.unsqueeze(1) bij_same_entity = entity_id.unsqueeze(2) == entity_id.unsqueeze(1) dij_residue = residue_index.unsqueeze(2) - residue_index.unsqueeze(1) dij_residue = torch.clip( dij_residue + self.n_relative_residx_bins, 0, 2 * self.n_relative_residx_bins, ) dij_residue = torch.where( bij_same_chain, dij_residue, 2 * self.n_relative_residx_bins + 1 ) aij_rel_pos = F.one_hot(dij_residue, 2 * self.n_relative_residx_bins + 2) dij_token = torch.clip( token_index.unsqueeze(2) - token_index.unsqueeze(1) + self.n_relative_residx_bins, 0, 2 * self.n_relative_residx_bins, ) dij_token = torch.where( bij_same_chain & bij_same_residue, dij_token, 2 * self.n_relative_residx_bins + 1, ) aij_rel_token = F.one_hot(dij_token, 2 * self.n_relative_residx_bins + 2) dij_chain = torch.clip( sym_id.unsqueeze(2) - sym_id.unsqueeze(1) + self.n_relative_chain_bins, 0, 2 * self.n_relative_chain_bins, ) dij_chain = torch.where( bij_same_chain, 2 * self.n_relative_chain_bins + 1, dij_chain ) aij_rel_chain = F.one_hot(dij_chain, 2 * self.n_relative_chain_bins + 2) feats = torch.cat( [ aij_rel_pos.float(), aij_rel_token.float(), bij_same_entity.float().unsqueeze(-1), aij_rel_chain.float(), ], dim=-1, ) return self.embed(feats) # =========================================================================== # SingleToPair (for LanguageModelShim) # =========================================================================== class SingleToPair(nn.Module): """downproject, output_mlp (Sequential of Linear, GELU, Linear).""" def __init__(self, input_dim: int, downproject_dim: int, output_dim: int) -> None: super().__init__() self.downproject = nn.Linear(input_dim, downproject_dim) self.output_mlp = nn.Sequential( nn.Linear(2 * downproject_dim, output_dim), nn.GELU(), nn.Linear(output_dim, output_dim), ) def forward(self, x: Tensor) -> Tensor: x = self.downproject(x) x = torch.cat( [(x.unsqueeze(2) * x.unsqueeze(1)), (x.unsqueeze(2) - x.unsqueeze(1))], dim=3, ) return self.output_mlp(x) # =========================================================================== # LanguageModelShim # =========================================================================== class LanguageModelShim(nn.Module): """Shim holding the trainable projection weights for LM integration. Contains: - base_z_combine: nn.Parameter [num_layers+1] - base_z_linear: Sequential(LayerNorm(d_model), Linear(d_model, d_z, bias=False)) - base_z_mlp: Sequential(SingleToPair(d_z, d_z, d_z), LayerNorm(d_z)) """ def __init__( self, d_z: int = 256, d_model: int = 2560, num_layers: int = 80 ) -> None: super().__init__() self.base_z_mlp = nn.Sequential(SingleToPair(d_z, d_z, d_z), nn.LayerNorm(d_z)) self.base_z_linear = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, d_z, bias=False) ) self.base_z_combine = nn.Parameter(torch.zeros(num_layers + 1)) def forward(self, hidden_states: Tensor, *, lm_dropout: float = 0.0) -> Tensor: """Project pre-computed ESMC hidden states to pair representation. Args: hidden_states: [B, L, num_layers+1, d_model] from ESMC 6B. lm_dropout: Dropout probability applied to the pair representation after ``base_z_mlp``. Returns: [B, L, L, d_pair] pair representation. """ lm_z = self.base_z_linear(hidden_states) # [B, L, 81, d_z] weights = self.base_z_combine.softmax(0) # [81] lm_z = (weights @ lm_z).squeeze(-2) # [B, L, d_z] lm_z = self.base_z_mlp(lm_z) # [B, L, L, d_z] if lm_dropout > 0: lm_z = F.dropout(lm_z, p=lm_dropout, training=True) return lm_z # =========================================================================== # Reproducibility helper (mirrors evolutionaryscale.utils.reproducibility) # =========================================================================== @contextmanager def _seed_context(seed: int | None, *, cuda: bool = True): """Temporarily seed Python, NumPy, and PyTorch RNGs.""" if seed is None: yield return py_state = random.getstate() np_state = np.random.get_state() torch_state = torch.get_rng_state() cuda_states = ( torch.cuda.get_rng_state_all() if cuda and torch.cuda.is_available() else None ) seed = int(seed) % (2**32) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if cuda and torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) try: yield finally: random.setstate(py_state) np.random.set_state(np_state) torch.set_rng_state(torch_state) if cuda_states is not None: torch.cuda.set_rng_state_all(cuda_states) # =========================================================================== # ESMFold2ExperimentalModel — the top-level PreTrainedModel # =========================================================================== def compute_lm_hidden_states( esmc: nn.Module, input_ids: Tensor, asym_id: Tensor, residue_index: Tensor, mol_type: Tensor, token_mask: Tensor, pad_to_multiple: int | None = None, lm_mask_pct: float = 0.0, mask_token_id: int = 32, ) -> Tensor: """Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers. Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple structure tokens but share a single ``(asym_id, residue_index)`` key — collapse them to one LM token per residue before running the LM (the LM was trained on per-residue inputs, not per-atom), then scatter the hidden states back to the per-token layout. """ B, L = input_ids.shape device = input_ids.device protein_mask = (mol_type == 0) & token_mask lm_input_list = [] lm_lengths = [] # Per-batch maps from (original protein-token index) to (LM input position). expand_maps: list[Tensor] = [] for b in range(B): mask_b = protein_mask[b] ids_b = input_ids[b][mask_b] asym_b = asym_id[b][mask_b] res_b = residue_index[b][mask_b] # Collapse: keep first token per (asym_id, residue_index) key, in # input order. ``inverse`` maps each original protein-token to its # collapsed residue index. keys = torch.stack((asym_b, res_b), dim=1) unique_keys, inverse = torch.unique(keys, dim=0, return_inverse=True) n_unique = unique_keys.size(0) token_positions = torch.arange(keys.size(0), device=device, dtype=torch.long) first_pos = torch.full( (n_unique,), keys.size(0), device=device, dtype=torch.long ) first_pos.scatter_reduce_( 0, inverse, token_positions, reduce="amin", include_self=True ) ordered = torch.argsort(first_pos) first_pos_ordered = first_pos[ordered] ids_collapsed = ids_b[first_pos_ordered] asym_collapsed = asym_b[first_pos_ordered] remap = torch.empty_like(ordered) remap[ordered] = torch.arange(n_unique, device=device, dtype=torch.long) inverse_ordered = remap[inverse] chain_ids = asym_collapsed.unique(sorted=True) # [BOS] chain1 [EOS BOS] chain2 ... [EOS] parts: list[Tensor] = [torch.tensor([0], device=device, dtype=ids_b.dtype)] # Per-chain LM positions accumulate; track them for the expand map. per_token_lm_pos = torch.empty(n_unique, device=device, dtype=torch.long) cursor = 1 # position 0 is the leading BOS for i, cid in enumerate(chain_ids): in_chain = (asym_collapsed == cid).nonzero(as_tuple=True)[0] parts.append(ids_collapsed[in_chain]) per_token_lm_pos[in_chain] = torch.arange( cursor, cursor + in_chain.shape[0], device=device, dtype=torch.long ) cursor += in_chain.shape[0] if i < len(chain_ids) - 1: parts.append(torch.tensor([2, 0], device=device, dtype=ids_b.dtype)) cursor += 2 # EOS + BOS parts.append(torch.tensor([2], device=device, dtype=ids_b.dtype)) lm_seq = torch.cat(parts) lm_input_list.append(lm_seq) lm_lengths.append(lm_seq.shape[0]) # Original protein-token position → LM input position. prot_pos_b = mask_b.nonzero(as_tuple=True)[0] expand_map = torch.full((L,), -1, device=device, dtype=torch.long) expand_map[prot_pos_b] = per_token_lm_pos[inverse_ordered] expand_maps.append(expand_map) # Pad to longest LM input; round to ``pad_to_multiple`` when fp8 is on # (TE fp8 kernels assert prod(shape[:-1]) % 8 == 0). max_len = max(lm_lengths) if pad_to_multiple is not None and pad_to_multiple > 1: max_len = ((max_len + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple lm_input_ids = torch.full( (B, max_len), 1, device=device, dtype=input_ids.dtype, # PAD=1 ) for b in range(B): lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b] # sequence_id for chain-aware attention; PAD tokens get -1 (no attention). sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0 sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1 if lm_mask_pct > 0.0: special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2) do_mask = ( torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct ) & ~special lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id) with torch.inference_mode(): esmc_out = esmc( input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True ) hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D] n_layers_plus_1, _, _, D = hs.shape result = torch.zeros(B, L, n_layers_plus_1, D, device=device, dtype=hs.dtype) for b in range(B): mb = protein_mask[b] em = expand_maps[b][mb] # [n_protein_tokens] LM positions # hs[:, b, em, :] -> [n_layers+1, n_protein_tokens, D] gathered = hs[:, b, em, :].permute(1, 0, 2) result[b, mb.nonzero(as_tuple=True)[0]] = gathered return result.detach() # =========================================================================== # TriangleMultiplicativeUpdate # =========================================================================== class TriangleMultiplicativeBlock(nn.Module): """Triangle multiplicative update block with gated signal routing.""" _FLOW_TO_EINSUM = {"outgoing": "bikd,bjkd->bijd", "incoming": "bkid,bkjd->bijd"} _VALID_FLOWS = ("outgoing", "incoming") def __init__(self, input_channels: int, latent_channels: int, flow: str) -> None: super().__init__() if flow not in self._FLOW_TO_EINSUM: raise ValueError( f"Invalid flow={flow!r}. Expected one of {self._VALID_FLOWS}." ) self.input_channels = input_channels self.latent_channels = latent_channels self.flow = flow self._einsum_equation = self._FLOW_TO_EINSUM[flow] self.norm_start = nn.LayerNorm(self.input_channels, eps=_EPS) self.norm_mix = nn.LayerNorm(self.latent_channels, eps=_EPS) self.proj_bundle = nn.Linear( self.input_channels, 4 * self.latent_channels, bias=False ) self.proj_emit = nn.Linear( self.latent_channels, self.input_channels, bias=False ) self.proj_gate = nn.Linear(self.input_channels, self.input_channels, bias=False) self._use_kernels: bool = False # Default chunked for memory on long sequences; tests override with # ``set_chunk_size(None)`` for the unchunked path under bit-exact bf16 # parity checks. self._chunk_size: int | None = 64 def set_chunk_size(self, chunk_size: int | None) -> None: self._chunk_size = chunk_size def split_kernel_weights(self) -> tuple[Tensor, Tensor]: return ( self.proj_bundle.weight[: 2 * self.latent_channels, :], self.proj_bundle.weight[2 * self.latent_channels :, :], ) def _kernel_flow_direction(self) -> str: return self.flow def _triangular_contract(self, left_stream: Tensor, right_stream: Tensor) -> Tensor: return torch.einsum(self._einsum_equation, left_stream, right_stream) def _triangular_contract_chunked( self, left_stream: Tensor, right_stream: Tensor, chunk_size: int ) -> Tensor: """Compute the triangular einsum in chunks along the output i-dimension.""" L = left_stream.shape[1] if self.flow == "outgoing" else left_stream.shape[2] chunks = [] for start in range(0, L, chunk_size): end = min(start + chunk_size, L) if self.flow == "outgoing": chunk = torch.einsum( self._einsum_equation, left_stream[:, start:end], right_stream ) else: chunk = torch.einsum( self._einsum_equation, left_stream[:, :, start:end], right_stream ) chunks.append(chunk) return torch.cat(chunks, dim=1) def forward(self, pair_grid: Tensor, visibility: Tensor | None = None) -> Tensor: if visibility is None: visibility = pair_grid.new_ones(pair_grid.shape[:-1]) if self._use_kernels: p_in_weight, g_in_weight = self.split_kernel_weights() try: return _cue_tri_mul( # type: ignore[misc] pair_grid, direction=self._kernel_flow_direction(), mask=visibility, norm_in_weight=self.norm_start.weight, norm_in_bias=self.norm_start.bias, p_in_weight=p_in_weight, g_in_weight=g_in_weight, norm_out_weight=self.norm_mix.weight, norm_out_bias=self.norm_mix.bias, p_out_weight=self.proj_emit.weight, g_out_weight=self.proj_gate.weight, eps=_EPS, ) except Exception as e: import logging as _logging _logging.getLogger(__name__).warning( "cuequivariance triangle_multiplicative_update kernel failed " "(flow=%s, shape=%s, dtype=%s); falling back to chunked einsum. " "Error: %s", self.flow, tuple(pair_grid.shape), pair_grid.dtype, e, ) normalized_grid = self.norm_start(pair_grid) bundled = self.proj_bundle(normalized_grid) signal, gate_logits = bundled.split(2 * self.latent_channels, dim=-1) routed = signal * torch.sigmoid(gate_logits) routed = routed * visibility.unsqueeze(-1) left_stream, right_stream = routed.float().chunk(2, dim=-1) if self._chunk_size is not None: contracted = self._triangular_contract_chunked( left_stream, right_stream, self._chunk_size ) else: contracted = self._triangular_contract(left_stream, right_stream) mixed = self.proj_emit(self.norm_mix(contracted)) output_gate = torch.sigmoid(self.proj_gate(normalized_grid)) return mixed * output_gate class TriangleMultiplicativeUpdate(nn.Module): """Thin wrapper exposing the triangular mixer with explicit orientation (v3).""" def __init__(self, dim: int = 128, _outgoing: bool = True) -> None: super().__init__() flow = "outgoing" if _outgoing else "incoming" self._engine = TriangleMultiplicativeBlock( input_channels=dim, latent_channels=dim, flow=flow ) def set_kernel_backend(self, backend: str | None) -> None: # Engine uses cueq when backend=="cuequivariance"; the "fused" backend # routes through the parent PairUpdateBlock's fused path (bypassing this). self._engine._use_kernels = backend == BACKEND_CUEQ if backend == BACKEND_CUEQ and not CUE_AVAILABLE: raise RuntimeError( "backend='cuequivariance' but cuequivariance_torch is not installed." ) def set_chunk_size(self, chunk_size: int | None) -> None: self._engine.set_chunk_size(chunk_size) def forward(self, z: Tensor, mask: Tensor | None = None) -> Tensor: return self._engine(z, visibility=mask) # =========================================================================== # FoldingTrunk: Transition, PairUpdateBlock, FoldingTrunk # =========================================================================== class Transition(nn.Module): """LN + SwiGLU FFN with addmm-fused residual; optional Triton LN+w12+SwiGLU kernel.""" def __init__(self, d_model: int, expansion_ratio: int = 4) -> None: super().__init__() self.norm = nn.LayerNorm(d_model) self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False) # Default chunked; set_chunk_size(None) disables for bit-exact parity tests. self._chunk_size: int | None = 64 self._fused_swiglu: nn.Module | None = None self._kernel_backend: str | None = None def set_chunk_size(self, chunk_size: int | None) -> None: self._chunk_size = chunk_size def set_kernel_backend(self, backend: str | None) -> None: """Install / uninstall FusedLNLinearSwiGLU (no cueq equivalent).""" if backend not in _VALID_BACKENDS: raise ValueError( f"backend must be one of {_VALID_BACKENDS}, got {backend!r}" ) self._kernel_backend = backend if backend == BACKEND_FUSED and TRITON_KERNELS_AVAILABLE: assert _FusedLNLinearSwiGLU is not None d_model = self.norm.normalized_shape[0] d_inner = self.ffn.hidden_features has_ln_bias = self.norm.bias is not None device = self.ffn.w12.weight.device dtype = self.ffn.w12.weight.dtype fused = _FusedLNLinearSwiGLU( d_model=d_model, d_inner=d_inner, has_ln_bias=has_ln_bias, device=device, dtype=dtype, ) with torch.no_grad(): fused.LN_W.copy_(self.norm.weight) if has_ln_bias: fused.LN_B.copy_(self.norm.bias) # type: ignore[union-attr] # FusedLNLinearSwiGLU.W12 is (d_model, 2*d_inner); transpose nn.Linear once. fused.W12.copy_(self.ffn.w12.weight.t().contiguous()) self._fused_swiglu = fused.eval().requires_grad_(False) else: self._fused_swiglu = None def _can_use_fused_path(self, x: Tensor) -> bool: return ( _fused_active(self, x) and self._fused_swiglu is not None and x.dtype == torch.bfloat16 ) def _swiglu_pre_w3(self, x_normed: Tensor) -> Tensor: """SwiGLU through silu(x1)*x2, before the final w3.""" ffn = self.ffn x12 = ffn.w12(x_normed) x1, x2 = x12.split(ffn.hidden_features, dim=-1) return F.silu(x1) * x2 def _addmm_residual(self, x: Tensor, hidden: Tensor) -> Tensor: """x + w3(hidden) via single cuBLAS addmm — avoids transition-output allocation.""" ffn = self.ffn x_shape = x.shape out = torch.addmm( x.contiguous().view(-1, x_shape[-1]), hidden.view(-1, hidden.shape[-1]), ffn.w3.weight.t(), ) return out.view(x_shape) def forward(self, x: Tensor) -> Tensor: # Inference-only fast path (addmm-fused residual + pre-alloc out) # — diverges bit-exactly from ``x + ffn(norm(x))`` so we only use # it when grad is disabled (binder-design / bit-exact tests run # with grad on and need the reference path). if not torch.is_grad_enabled() and self._can_use_fused_path(x): fused = self._fused_swiglu assert fused is not None pre_w3 = fused if self._chunk_size is None or x.shape[1] <= self._chunk_size: hidden = pre_w3(x) return self._addmm_residual(x, hidden) out = torch.empty_like(x) for s in range(0, x.shape[1], self._chunk_size): e = min(s + self._chunk_size, x.shape[1]) sl = x[:, s:e] hidden = pre_w3(sl) out[:, s:e] = self._addmm_residual(sl, hidden) return out # Reference path — bit-exact with main: x + ffn(norm(x)). if self._chunk_size is None or x.shape[1] <= self._chunk_size: return x + self.ffn(self.norm(x)) out_list: list[Tensor] = [] for s in range(0, x.shape[1], self._chunk_size): e = min(s + self._chunk_size, x.shape[1]) sl = x[:, s:e] out_list.append(sl + self.ffn(self.norm(sl))) return torch.cat(out_list, dim=1) class PairUpdateBlock(nn.Module): """tri_mul_out, tri_mul_in, pair_transition.""" def __init__(self, d_pair: int = 256, expansion_ratio: int = 4) -> None: super().__init__() self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True) self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False) self.pair_transition = Transition(d_pair, expansion_ratio=expansion_ratio) self._kernel_backend: str | None = None # Row-shared dropout-residual; r=0 for inference (HF model is inference-only). # backend='fused' swaps in the FusedDropoutResidual Triton kernel. self.row_drop = DropoutResidual(0.0, batch_dim=1, use_fused_kernels=False) def set_kernel_backend(self, backend: str | None) -> None: if backend not in _VALID_BACKENDS: raise ValueError( f"backend must be one of {_VALID_BACKENDS}, got {backend!r}" ) self.tri_mul_out.set_kernel_backend(backend) self.tri_mul_in.set_kernel_backend(backend) self.pair_transition.set_kernel_backend(backend) self._kernel_backend = backend self.row_drop = DropoutResidual( 0.0, batch_dim=1, use_fused_kernels=(backend == BACKEND_FUSED) ) def set_chunk_size(self, chunk_size: int | None) -> None: self.tri_mul_out.set_chunk_size(chunk_size) self.tri_mul_in.set_chunk_size(chunk_size) self.pair_transition.set_chunk_size(chunk_size) def _can_use_fused_trimul_with_residual(self, pair: Tensor) -> bool: return _fused_active(self, pair) and pair.dtype == torch.bfloat16 def _fused_trimul_with_residual( self, pair: Tensor, direction: str, pair_attention_mask: Tensor | None ) -> Tensor: """Fused TriMul+residual call; weights from the corresponding engine.""" tri = self.tri_mul_out if direction == "outgoing" else self.tri_mul_in engine: TriangleMultiplicativeBlock = tri._engine # type: ignore[assignment] p_in_weight, g_in_weight = engine.split_kernel_weights() def _bf16(t: Tensor) -> Tensor: return t if t.dtype == torch.bfloat16 else t.to(torch.bfloat16) return _fused_trimul_with_residual( # type: ignore[misc] pair, direction, residual=pair, drop_mask=None, # inference: no dropout, matches internal's eval path norm_in_weight=_bf16(engine.norm_start.weight), norm_in_bias=_bf16(engine.norm_start.bias), p_in_weight=_bf16(p_in_weight), g_in_weight=_bf16(g_in_weight), norm_out_weight=_bf16(engine.norm_mix.weight), norm_out_bias=_bf16(engine.norm_mix.bias), p_out_weight=_bf16(engine.proj_emit.weight), g_out_weight=_bf16(engine.proj_gate.weight), mask=pair_attention_mask, eps=_EPS, ) def forward( self, pair: Tensor, pair_attention_mask: Tensor | None = None ) -> Tensor: if self._can_use_fused_trimul_with_residual(pair): pair = self._fused_trimul_with_residual( pair, "outgoing", pair_attention_mask ) pair = self._fused_trimul_with_residual( pair, "incoming", pair_attention_mask ) else: pair = self.row_drop(pair, self.tri_mul_out(pair, mask=pair_attention_mask)) pair = self.row_drop(pair, self.tri_mul_in(pair, mask=pair_attention_mask)) pair = self.pair_transition(pair) return pair class FoldingTrunk(nn.Module): """ModuleList of PairUpdateBlocks.""" def __init__( self, n_layers: int = 24, d_pair: int = 256, expansion_ratio: int = 4 ) -> None: super().__init__() self.blocks = nn.ModuleList( [ PairUpdateBlock(d_pair=d_pair, expansion_ratio=expansion_ratio) for _ in range(n_layers) ] ) def set_kernel_backend(self, backend: str | None) -> None: for block in self.blocks: cast(PairUpdateBlock, block).set_kernel_backend(backend) def set_chunk_size(self, chunk_size: int | None) -> None: for block in self.blocks: cast(PairUpdateBlock, block).set_chunk_size(chunk_size) def forward( self, pair: Tensor, pair_attention_mask: Tensor | None = None ) -> Tensor: # Cast pair → bf16 internally when the fused trimul backend is enabled # (its bwd kernel requires bf16). Other backends keep the input dtype. orig_dtype = pair.dtype fused_on = ( len(self.blocks) > 0 and getattr(self.blocks[0], "_kernel_backend", None) == BACKEND_FUSED ) if pair.is_cuda and fused_on and orig_dtype != torch.bfloat16: pair = pair.to(torch.bfloat16) for block in self.blocks: fn = partial(block, pair_attention_mask=pair_attention_mask) if torch.is_grad_enabled(): pair = checkpoint(fn, pair, use_reentrant=False) # pyright: ignore else: pair = fn(pair) if pair.dtype != orig_dtype: pair = pair.to(orig_dtype) return pair # =========================================================================== # MSA Encoder # =========================================================================== class OuterProductMean(nn.Module): """Outer-product mean: maps an MSA representation into a pair update. The order of the ``/ n_valid`` divide vs. the ``Wout`` projection is selectable via ``divide_outer_before_proj`` because different ESMFold2 checkpoints were trained with different orderings: * ``False`` (default): ``Wout(outer) / n_valid`` — the projection bias is scaled by 1/n_valid alongside the outer product. * ``True``: ``Wout(outer / n_valid)`` — the projection bias is added unscaled, post-divide. """ def __init__( self, d_msa: int, d_hidden: int, d_pair: int, divide_outer_before_proj: bool = False, ) -> None: super().__init__() self.d_hidden = d_hidden self.divide_outer_before_proj = divide_outer_before_proj self.norm = nn.LayerNorm(d_msa) self.W = nn.Linear(d_msa, 2 * d_hidden, bias=False) self.Wout = nn.Linear(d_hidden * d_hidden, d_pair, bias=True) # Off for bit-exact bf16; ``set_chunk_size(64)`` for long sequences. self._chunk_size: int | None = None def set_chunk_size(self, chunk_size: int | None) -> None: self._chunk_size = chunk_size def forward(self, m: Tensor, msa_attention_mask: Tensor) -> Tensor: m_norm = self.norm(m) x = self.W(m_norm) * msa_attention_mask.unsqueeze(-1).to(m_norm.dtype) a, b = x.chunk(2, dim=-1) mask_f = msa_attention_mask.to(a.dtype) n_valid = (mask_f @ mask_f.transpose(-1, -2)).unsqueeze(-1).clamp(min=1.0) if self._chunk_size is None: outer = torch.einsum("bimc,bjmd->bijcd", a, b).flatten(-2) if self.divide_outer_before_proj: return self.Wout(outer / n_valid) return self.Wout(outer) / n_valid # Chunk along the left (i) axis so the peak einsum intermediate is # [B, chunk, L, c, d] instead of [B, L, L, c, d]. L = a.shape[1] out_chunks: list[Tensor] = [] for s in range(0, L, self._chunk_size): e = min(s + self._chunk_size, L) outer_chunk = torch.einsum("bimc,bjmd->bijcd", a[:, s:e], b).flatten(-2) if self.divide_outer_before_proj: out_chunks.append(self.Wout(outer_chunk / n_valid[:, s:e])) else: out_chunks.append(self.Wout(outer_chunk) / n_valid[:, s:e]) return torch.cat(out_chunks, dim=1) class MSAPairWeightedAveraging(nn.Module): """Pair-biased MSA row update (AF3 Supplement Algorithm 10).""" def __init__( self, d_msa: int, d_pair: int, n_heads: int = 8, head_width: int = 32 ) -> None: super().__init__() self.n_heads = n_heads self.head_width = head_width self.norm_single = nn.LayerNorm(d_msa) self.compute_bias = nn.Sequential( nn.LayerNorm(d_pair), nn.Linear(d_pair, n_heads, bias=False) ) self.Wv = nn.Linear(d_msa, n_heads * head_width, bias=False) self.Wgate = nn.Linear(d_msa, n_heads * head_width, bias=False) self.Wout = nn.Linear(n_heads * head_width, d_msa, bias=False) def forward( self, msa_repr: Tensor, pair_repr: Tensor, pair_attention_mask: Tensor ) -> Tensor: """ Args: msa_repr: [B, L, M, d_msa] pair_repr: [B, L, L, d_pair] pair_attention_mask:[B, L, L] Returns: [B, L, M, d_msa] """ B, L, M, _ = msa_repr.shape h, dh = self.n_heads, self.head_width msa_normed = self.norm_single(msa_repr) bias = self.compute_bias(pair_repr) # [B, L, L, n_heads] bias.masked_fill_(~pair_attention_mask.unsqueeze(-1).bool(), -1e5) attn = torch.softmax(bias, dim=-2) # softmax over j v = self.Wv(msa_normed).reshape(B, L, M, h, dh) gate = torch.sigmoid(self.Wgate(msa_normed)).reshape(B, L, M, h, dh) output = torch.einsum("bijh,bjmhd,bimhd->bimhd", attn, v, gate) return self.Wout(output.reshape(B, L, M, h * dh))