Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # coding=utf-8 | |
| # Copyright 2026 Biohub. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| """Shared building blocks for ESMFold2 HuggingFace model variants.""" | |
| from __future__ import annotations | |
| import random | |
| import importlib | |
| from contextlib import contextmanager | |
| from functools import partial | |
| from typing import cast | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from torch.utils.checkpoint import checkpoint | |
| try: | |
| flash_attn_module = importlib.import_module("flash_attn") | |
| flash_bert_padding = importlib.import_module("flash_attn.bert_padding") | |
| flash_attn_func = flash_attn_module.flash_attn_func | |
| flash_attn_varlen_func = flash_attn_module.flash_attn_varlen_func | |
| index_first_axis = flash_bert_padding.index_first_axis | |
| pad_input = flash_bert_padding.pad_input | |
| FLASH_ATTN_AVAILABLE = True | |
| except ImportError: | |
| flash_attn_func = None # type: ignore[assignment] | |
| flash_attn_varlen_func = None # type: ignore[assignment] | |
| index_first_axis = None # type: ignore[assignment] | |
| pad_input = None # type: ignore[assignment] | |
| FLASH_ATTN_AVAILABLE = False | |
| try: | |
| cue_module = importlib.import_module("cuequivariance_torch") | |
| cue_triangle = importlib.import_module("cuequivariance_torch.primitives.triangle") | |
| _cue_attn_pair_bias = cue_module.attention_pair_bias | |
| _cue_tri_mul = cue_triangle.triangle_multiplicative_update | |
| CUE_AVAILABLE = True | |
| except ImportError: | |
| _cue_attn_pair_bias = None # type: ignore[assignment] | |
| _cue_tri_mul = None # type: ignore[assignment] | |
| CUE_AVAILABLE = False | |
# The Biohub release includes optional Triton kernels. FastPLMs keeps the
# reference path enabled by default so Hugging Face remote-code loading can stay
# flat and self-contained.
# All fused-kernel handles are None and TRITON_KERNELS_AVAILABLE is False, so
# ``_fused_active`` (which requires TRITON_KERNELS_AVAILABLE) is always False and
# ``DropoutResidual`` (which requires ``_FusedDropoutResidual is not None``)
# always takes its unfused PyTorch path.
_fused_pair_bias = None
_fused_trimul_with_residual = None
_FusedLNLinearSwiGLU = None
_FusedDropoutResidual = None
TRITON_KERNELS_AVAILABLE = False
| from .configuration_esmfold2 import ESMFold2Config | |
# Backend identifiers compared against a module's ``_kernel_backend`` attribute
# by ``_fused_active`` / ``_cueq_active``. ``_VALID_BACKENDS`` lists the allowed
# settings; NOTE(review): ``None`` presumably selects the reference PyTorch path
# — confirm where ``_kernel_backend`` is assigned.
BACKEND_FUSED = "fused"
BACKEND_CUEQ = "cuequivariance"
_VALID_BACKENDS = (None, BACKEND_FUSED, BACKEND_CUEQ)
def _fused_active(module: nn.Module, tensor: Tensor) -> bool:
    """Return True when the vendored fused Triton inference kernels may run.

    Every condition must hold: the Triton kernels are loaded, ``module`` has
    selected the fused backend, autograd is disabled, and ``tensor`` is a CUDA
    tensor.
    """
    if not TRITON_KERNELS_AVAILABLE:
        return False
    if getattr(module, "_kernel_backend", None) != BACKEND_FUSED:
        return False
    if torch.is_grad_enabled():
        return False
    return tensor.is_cuda
def _cueq_active(module: nn.Module) -> bool:
    """Return True when cuEquivariance is importable and ``module`` selected it."""
    if not CUE_AVAILABLE:
        return False
    return getattr(module, "_kernel_backend", None) == BACKEND_CUEQ
class DropoutResidual(nn.Module):
    """Compute ``residual + dropout(delta)`` with a mask shared along one axis.

    The dropout mask is sampled with extent 1 along ``batch_dim`` and then
    broadcast, so every slice along that axis drops the same entries
    (row/column-shared dropout). When ``use_fused_kernels`` is requested,
    ``batch_dim == 1`` and the ``_FusedDropoutResidual`` kernel is loaded, the
    call is delegated to that kernel (single pass, in-place residual add);
    otherwise the unfused PyTorch path is used.

    Args:
        r: dropout probability.
        batch_dim: axis (1 or 2) along which the mask is shared.
        use_fused_kernels: request the fused kernel when it is available.
    """

    def __init__(
        self, r: float, batch_dim: int, use_fused_kernels: bool = False
    ) -> None:
        super().__init__()
        assert batch_dim in (1, 2), f"batch_dim must be 1 or 2, got {batch_dim}"
        can_fuse = (
            use_fused_kernels and batch_dim == 1 and _FusedDropoutResidual is not None
        )
        self._use_fused_kernels = can_fuse
        self._batch_dim = batch_dim
        self._r = r
        if not can_fuse:
            self._impl: nn.Module = nn.Dropout(r)
        else:
            assert _FusedDropoutResidual is not None
            self._impl = _FusedDropoutResidual(r)

    def forward(self, residual: Tensor, delta: Tensor) -> Tensor:
        if self._use_fused_kernels:
            return self._impl(residual, delta)
        # Dropout is the identity in eval mode or at rate 0: skip mask sampling.
        if not self.training or self._r == 0.0:
            return residual + delta
        # A mask of ones with size 1 on the shared axis; nn.Dropout zeroes
        # entries and rescales survivors, then broadcasting shares the mask.
        mask_shape = list(delta.shape)
        mask_shape[self._batch_dim] = 1
        keep = self._impl(delta.new_ones(mask_shape))
        return residual + delta * keep
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Per-character vocabulary width of the atom-name feature (presumably one-hot).
CHAR_VOCAB_SIZE: int = 64
# Characters per atom name; ``ref_atom_name_chars`` is flattened to
# MAX_CHARS * CHAR_VOCAB_SIZE features in ``ESMFold2AtomEncoder.forward``.
MAX_CHARS: int = 4
# Cartesian coordinate dimensions (also the atom decoder's output width).
XYZ_DIMS: int = 3
# Width of the ``ref_element`` feature — presumably a one-hot over atomic number.
MAX_ATOMIC_NUMBER: int = 128
# Input feature dim = 3 + 1 + 1 + 128 + 64*4 = 389
# (ref_pos + ref_charge + atom mask + ref_element + atom-name chars, matching the
# concatenation order in ``ESMFold2AtomEncoder.forward``).
ATOM_FEATURE_DIM: int = (
    XYZ_DIMS + 1 + 1 + MAX_ATOMIC_NUMBER + CHAR_VOCAB_SIZE * MAX_CHARS
)
# NOTE(review): number of residue/token types; consumer not visible in this chunk.
NUM_RES_TYPES: int = 33
# NOTE(review): shared epsilon; usage not visible in this chunk.
_EPS = 1e-5
# Default for the triangle / OPM / pair-transition L² ops. Caps peak memory
# so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438;
# chunk=64 leaves headroom for the largest foldbench targets). Override via
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
# short L but OOM-prone past ~600).
_DEFAULT_CHUNK_SIZE = 64
| # =========================================================================== | |
| # MSA inference-time diversity augmentations | |
| # =========================================================================== | |
def maybe_subsample_msa(
    msa: Tensor,
    msa_attention_mask: Tensor | None,
    has_deletion: Tensor | None,
    deletion_value: Tensor | None,
    *,
    max_depth: int | None,
    enabled: bool,
) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]:
    """Randomly subsample MSA rows down to ``max_depth``, always keeping row 0.

    Rows are indexed along dim 1 of every tensor. Row 0 (the query) is always
    kept. ``max_depth - 1`` further rows are drawn without replacement from the
    remaining rows using the global torch RNG. The selection is sorted so the
    original row order is preserved, and the same selection is applied to the
    optional mask / deletion tensors.

    Args:
        msa: MSA tensor whose dim 1 is the alignment depth.
        msa_attention_mask, has_deletion, deletion_value: optional tensors
            indexed in lockstep with ``msa`` along dim 1.
        max_depth: maximum number of rows to keep, or None for no limit.
        enabled: when False, the inputs are returned unchanged.

    Returns:
        ``(msa, msa_attention_mask, has_deletion, deletion_value)``. The tuple
        is subsampled when enabled and the depth exceeds ``max_depth``;
        otherwise the inputs are returned as-is.

    Raises:
        ValueError: if subsampling is required but ``max_depth < 1``, since
            the query row could not be kept.
    """
    if not enabled or max_depth is None:
        return msa, msa_attention_mask, has_deletion, deletion_value
    depth = msa.size(1)
    if depth <= 1 or depth <= max_depth:
        return msa, msa_attention_mask, has_deletion, deletion_value
    if max_depth < 1:
        # Without this check, max_depth == 0 yields a shape mismatch (the
        # ``[: max_depth - 1]`` slice becomes ``[: -1]``) and negative values a
        # negative-size error; report the real problem instead.
        raise ValueError(
            f"max_depth must be >= 1 to subsample an MSA, got {max_depth}"
        )
    indices = torch.zeros(max_depth, dtype=torch.long, device=msa.device)
    # Offset by 1 so the random picks never duplicate the query row (index 0).
    indices[1:] = torch.randperm(depth - 1, device=msa.device)[: max_depth - 1] + 1
    indices = indices.sort().values

    def _take(t: Tensor | None) -> Tensor | None:
        # Apply the shared row selection to an optional companion tensor.
        return None if t is None else t[:, indices]

    return (
        msa[:, indices],
        _take(msa_attention_mask),
        _take(has_deletion),
        _take(deletion_value),
    )
def maybe_apply_msa_column_masking(
    msa_attention_mask: Tensor | None,
    rate: float,
) -> Tensor | None:
    """Randomly drop whole MSA columns from the attention mask.

    Each (batch, column) pair is dropped with probability ``rate`` using the
    global torch RNG. A dropped column is masked for every row except row 0,
    which is always kept. The input is returned untouched when it is None,
    when ``rate <= 0``, or when it has a single row (dim 1).

    Returns:
        ``msa_attention_mask.bool() & keep`` (same shape, bool), or the input
        itself in the no-op cases above.
    """
    nothing_to_do = (
        msa_attention_mask is None
        or rate <= 0.0
        or msa_attention_mask.size(1) <= 1
    )
    if nothing_to_do:
        return msa_attention_mask
    n_batch, _, n_cols = msa_attention_mask.shape
    keep_column = (
        torch.rand(n_batch, n_cols, device=msa_attention_mask.device) >= rate
    )
    keep = keep_column[:, None, :].expand_as(msa_attention_mask).clone()
    keep[:, 0] = True  # never mask the query row
    return msa_attention_mask.bool() & keep
| # =========================================================================== | |
| # Atom-token utilities | |
| # =========================================================================== | |
def gather_token_to_atom(token_features: Tensor, atom_to_token_idx: Tensor) -> Tensor:
    """Copy each atom's owning-token feature vector onto the atom.

    Args:
        token_features: [B, L, d]
        atom_to_token_idx: [B, A] int64 token index of each atom.

    Returns:
        [B, A, d] with ``out[b, a] = token_features[b, atom_to_token_idx[b, a]]``.
    """
    feat_dim = token_features.size(-1)
    gather_index = atom_to_token_idx[..., None].expand(-1, -1, feat_dim)
    return token_features.gather(1, gather_index)
def scatter_atom_to_token(
    atom_features: Tensor,
    atom_to_token_idx: Tensor,
    n_tokens: int,
    atom_mask: Tensor | None = None,
) -> Tensor:
    """Mean-pool per-atom features into their tokens.

    Masked-out atoms are routed to an extra dump slot (index ``n_tokens``) that
    is sliced off before returning, so they do not contribute. Tokens that
    receive no atoms stay zero (``include_self=False`` on a zero buffer).

    Args:
        atom_features: [B, A, d]
        atom_to_token_idx: [B, A] int64 token index of each atom.
        n_tokens: L, the number of output tokens.
        atom_mask: optional [B, A] bool; atoms marked False are ignored.

    Returns:
        [B, L, d]
    """
    batch, n_atoms, dim = atom_features.shape
    if atom_mask is None:
        target, n_slots = atom_to_token_idx, n_tokens
    else:
        target = torch.where(atom_mask, atom_to_token_idx, n_tokens)
        n_slots = n_tokens + 1
    pooled = atom_features.new_zeros(batch, n_slots, dim)
    pooled.scatter_reduce_(
        1,
        target.unsqueeze(-1).expand(batch, n_atoms, dim),
        atom_features,
        reduce="mean",
        include_self=False,
    )
    return pooled[:, :n_tokens, :]
def gather_rep_atom_coords(coords: Tensor, rep_atom_idx: Tensor) -> Tensor:
    """Select one representative atom's coordinates for every token.

    Args:
        coords: [B, A, 3] atom coordinates.
        rep_atom_idx: [B, L] int64 atom index chosen for each token.

    Returns:
        [B, L, 3] with ``out[b, l] = coords[b, rep_atom_idx[b, l]]``.
    """
    index = rep_atom_idx[:, :, None].expand(-1, -1, coords.shape[-1])
    return coords.gather(1, index)
| def _compute_intra_token_idx(atom_to_token: Tensor) -> Tensor: | |
| """Compute local atom index within each token (vectorised). | |
| Atoms belonging to the same token are contiguous, so this computes a | |
| running count that resets at each token boundary. | |
| Args: | |
| atom_to_token: [B, A] flat index mapping each atom to its token. | |
| Returns: | |
| [B, A] tensor with values in [0, max_atoms_per_token - 1]. | |
| """ | |
| same_as_prev = F.pad( | |
| atom_to_token[:, 1:] == atom_to_token[:, :-1], (1, 0), value=False | |
| ) | |
| ones = torch.ones_like(atom_to_token) | |
| cumsum = torch.cumsum(ones, dim=-1) | |
| group_start = cumsum.masked_fill(same_as_prev, 0) | |
| group_start = torch.cummax(group_start, dim=-1).values | |
| return cumsum - group_start | |
| def _categorical_mean(logits: Tensor, start: float, end: float) -> Tensor: | |
| """Expected value of a categorical distribution over evenly-spaced bins. | |
| Equivalent to ``CategoricalMixture(logits, bins=logits.shape[-1], start, end).mean()``. | |
| Args: | |
| logits: [..., n_bins] | |
| start: left boundary | |
| end: right boundary | |
| Returns: | |
| [...] expected value | |
| """ | |
| n_bins = logits.shape[-1] | |
| edges = torch.linspace( | |
| start, end, n_bins + 1, device=logits.device, dtype=torch.float32 | |
| ) | |
| v_bins = (edges[:-1] + edges[1:]) / 2 # [n_bins] | |
| return (logits.float().softmax(-1) @ v_bins.unsqueeze(1)).squeeze(-1) | |
| # =========================================================================== | |
| # TransitionLayer (used in DiffusionConditioning) | |
| # =========================================================================== | |
class TransitionLayer(nn.Module):
    """SwiGLU transition: ``out_proj(silu(a_proj(LN(x))) * b_proj(LN(x)))``.

    The hidden width is ``n * d_model`` and every projection is bias-free.
    """

    def __init__(self, d_model: int, n: int, eps: float = 1e-5) -> None:
        super().__init__()
        width = n * d_model
        # Creation order (norm, a_proj, b_proj, out_proj) fixes the state_dict
        # keys and the RNG draw order at initialisation.
        self.norm = nn.LayerNorm(d_model, eps=eps)
        self.a_proj = nn.Linear(d_model, width, bias=False)
        self.b_proj = nn.Linear(d_model, width, bias=False)
        self.out_proj = nn.Linear(width, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        normed = self.norm(x)
        gated = F.silu(self.a_proj(normed)) * self.b_proj(normed)
        return self.out_proj(gated)
| # =========================================================================== | |
| # AdaptiveLayerNorm (used in DiffusionTransformer) | |
| # =========================================================================== | |
class AdaptiveLayerNorm(nn.Module):
    """Adaptive layer normalization (adaLN-Zero).

    ``a`` is layer-normalized without affine parameters. ``s`` is
    layer-normalized with a learned scale (``s_scale``) and no bias. The
    normalized conditioning produces a sigmoid gate (``s_gate``, with bias) that
    multiplies the normalized ``a``, plus an additive shift (``s_shift``, no
    bias).
    """

    def __init__(self, d_model: int, d_cond: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_cond = d_cond
        self.eps = eps
        self.s_scale = nn.Parameter(torch.ones(d_cond))
        self.s_gate = nn.Linear(d_cond, d_model, bias=True)
        self.s_shift = nn.Linear(d_cond, d_model, bias=False)

    def forward(self, a: Tensor, s: Tensor) -> Tensor:
        cond = F.layer_norm(s, (self.d_cond,), self.s_scale, None, self.eps)
        gate = torch.sigmoid(self.s_gate(cond))
        shift = self.s_shift(cond)
        normed = F.layer_norm(a, (self.d_model,), None, None, self.eps)
        return gate * normed + shift
| # =========================================================================== | |
| # FourierEmbedding | |
| # =========================================================================== | |
class FourierEmbedding(nn.Module):
    """Random Fourier features ``cos(2*pi*(t*w + b))`` of a scalar input.

    ``w`` and ``b`` are buffers (not trainable parameters) drawn from a standard
    normal at construction. The input is flattened, so the output is
    ``[N, c]`` where ``N`` is the number of elements of ``t_hat``.
    """

    w: Tensor
    b: Tensor

    def __init__(self, c: int) -> None:
        super().__init__()
        self.c = c
        # Registration order matters: w is drawn from the global RNG before b.
        self.register_buffer("w", torch.randn(c))
        self.register_buffer("b", torch.randn(c))

    def forward(self, t_hat: Tensor) -> Tensor:
        column = torch.as_tensor(
            t_hat, device=self.w.device, dtype=self.w.dtype
        ).reshape(-1, 1)
        phase = column * self.w + self.b
        return torch.cos(2.0 * torch.pi * phase)
| # =========================================================================== | |
| # SwiGLU / SwiGLUMLP | |
| # =========================================================================== | |
| def _compute_swiglu_hidden_size(d_model: int, expansion_ratio: int) -> int: | |
| return expansion_ratio * d_model | |
class SwiGLU(nn.Module):
    """SwiGLU MLP with a packed input projection.

    ``w12`` produces the gate half and then the value half in a single matmul.
    The output is ``w3(silu(gate) * value)``. ``out_features`` defaults to
    ``in_features`` when it is None (or 0).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int | None = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self.hidden_features = hidden_features

    def forward(self, x: Tensor) -> Tensor:
        gate, value = torch.split(self.w12(x), self.hidden_features, dim=-1)
        return self.w3(F.silu(gate) * value)
class SwiGLUMLP(SwiGLU):
    """Square SwiGLU MLP (``d_model -> d_model``), bias-free by default.

    The hidden width is ``_compute_swiglu_hidden_size(d_model, expansion_ratio)``
    and the weights are packed exactly as in ``SwiGLU``.
    """

    def __init__(
        self, d_model: int, expansion_ratio: int = 4, bias: bool = False
    ) -> None:
        super().__init__(
            in_features=d_model,
            hidden_features=_compute_swiglu_hidden_size(d_model, expansion_ratio),
            out_features=d_model,
            bias=bias,
        )
| # =========================================================================== | |
| # SWA Atom Attention components | |
| # =========================================================================== | |
| def _rotate_half(x: Tensor) -> Tensor: | |
| x1, x2 = x.chunk(2, dim=-1) | |
| return torch.cat((-x2, x1), dim=-1) | |
def apply_rotary_emb_3d(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply RoPE with batch-dependent cos/sin tables.

    The first ``2 * cos.shape[-1]`` channels of every head are rotated using
    the half-split convention of ``_rotate_half``, with the cos/sin tables
    tiled twice along the last dim. Any remaining channels pass through
    unchanged, and the same table is broadcast across heads.

    Args:
        x: [B, L, H, D]
        cos: [B, L, D/2]
        sin: [B, L, D/2]
    """
    rot = 2 * cos.shape[-1]
    cos_full = torch.cat([cos, cos], dim=-1).unsqueeze(2)  # [B, L, 1, rot]
    sin_full = torch.cat([sin, sin], dim=-1).unsqueeze(2)
    head, tail = x[..., :rot], x[..., rot:]
    head = head * cos_full + _rotate_half(head) * sin_full
    return torch.cat([head, tail], dim=-1)
def build_3d_rope(
    ref_pos: Tensor,
    ref_space_uid: Tensor,
    head_dim: int,
    n_spatial_per_axis: int = 4,
    n_uid_pairs: int = 2,
    spatial_base_freq: float = 10000.0,
    uid_base_freq: float = 10.0,
) -> tuple[Tensor, Tensor]:
    """Build cos/sin tables for 3D spatial RoPE plus a UID RoPE.

    Each of the 3 axes of ``ref_pos`` gets ``n_spatial_per_axis`` frequencies
    ``spatial_base_freq ** (-k / n_spatial_per_axis)``. ``ref_space_uid`` gets
    ``n_uid_pairs`` frequencies built the same way from ``uid_base_freq``.
    Angles are laid out as ``[x_0.., y_0.., z_0.., uid_0..]`` and zero-padded
    (cos=1, sin=0, i.e. an identity rotation) up to ``head_dim // 2`` pairs.

    Args:
        ref_pos: [B, N, 3] reference positions.
        ref_space_uid: [B, N] per-atom value used as a 1-D rotary position.
        head_dim: attention head dimension; ``head_dim // 2`` rotary pairs.
        n_spatial_per_axis: rotary pairs per spatial axis.
        n_uid_pairs: rotary pairs for the UID channel.
        spatial_base_freq: frequency base for the spatial pairs.
        uid_base_freq: frequency base for the UID pairs.

    Returns:
        ``(cos, sin)``, each [B, N, head_dim // 2] in bfloat16.

    Raises:
        ValueError: if ``3 * n_spatial_per_axis + n_uid_pairs`` exceeds
            ``head_dim // 2``. Such tables would be wider than the head, and
            ``apply_rotary_emb_3d`` would fail later with a shape mismatch.
    """
    device = ref_pos.device
    B, N = ref_pos.shape[:2]
    half_dim = head_dim // 2
    n_spatial_total = 3 * n_spatial_per_axis
    n_active = n_spatial_total + n_uid_pairs
    if n_active > half_dim:
        raise ValueError(
            f"3 * n_spatial_per_axis + n_uid_pairs = {n_active} rotary pairs "
            f"exceed head_dim // 2 = {half_dim}"
        )
    spatial_inv_freq = 1.0 / (
        spatial_base_freq
        ** (
            torch.arange(0, n_spatial_per_axis, dtype=torch.float32, device=device)
            / n_spatial_per_axis
        )
    )
    uid_inv_freq = 1.0 / (
        uid_base_freq
        ** (
            torch.arange(0, n_uid_pairs, dtype=torch.float32, device=device)
            / n_uid_pairs
        )
    )
    # [B, N, 3, k] -> [B, N, 3k]: all x frequencies, then y, then z.
    spatial_freqs = torch.einsum("bna,k->bnak", ref_pos.float(), spatial_inv_freq)
    spatial_freqs = spatial_freqs.reshape(B, N, n_spatial_total)
    uid_freqs = torch.einsum("bn,k->bnk", ref_space_uid.float(), uid_inv_freq)
    freqs = torch.cat([spatial_freqs, uid_freqs], dim=-1)
    if n_active < half_dim:
        padding = torch.zeros(
            B, N, half_dim - n_active, device=device, dtype=torch.float32
        )
        freqs = torch.cat([freqs, padding], dim=-1)
    cos = freqs.cos().to(torch.bfloat16)
    sin = freqs.sin().to(torch.bfloat16)
    return cos, sin
| def qk_norm(x: Tensor) -> Tensor: | |
| return F.rms_norm(x, (x.size(-1),)).to(x.dtype) | |
| # =========================================================================== | |
| # SwiGLUFFN (atom transformer blocks) | |
| # =========================================================================== | |
class SwiGLUFFN(nn.Module):
    """Bias-free SwiGLU FFN whose hidden width is rounded up to a multiple of 256.

    Hidden width: ``expansion_ratio * (d_model // 3) * 2`` rounded up to the
    next multiple of 256. Inputs are cast to the weight dtype before the up
    projection, so the output has the weight dtype.
    """

    def __init__(self, d_model: int, expansion_ratio: int = 2) -> None:
        super().__init__()
        align = 256
        raw_width = expansion_ratio * (d_model // 3) * 2
        hidden_size = (raw_width + align - 1) // align * align
        self.w_up = nn.Linear(d_model, 2 * hidden_size, bias=False)
        self.w_down = nn.Linear(hidden_size, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        projected = self.w_up(x.to(self.w_up.weight.dtype))
        gate, value = projected.chunk(2, dim=-1)
        return self.w_down(F.silu(gate) * value)
| # =========================================================================== | |
| # SWA3DRoPEAttention | |
| # =========================================================================== | |
class SWA3DRoPEAttention(nn.Module):
    """Gated sliding-window self-attention with QK-RMSNorm and 3D RoPE.

    Parameters: ``Wqkv`` (packed q/k/v projection), ``gate_proj`` (a sigmoid
    output gate computed from the block input) and ``out_proj``; none has a
    bias. Each query attends to keys within ``half_window`` positions on either
    side.

    ``attention_params`` is either ``(cos, sin)`` or
    ``(cos, sin, indices, cu_seqlens, max_seqlen)``. The longer form describes
    padding in flash-attn varlen layout, where ``indices`` holds the flat
    ``B*N`` positions of valid tokens.

    Backends: flash-attn (varlen when padding info is present) for CUDA tensors
    when the package is importable. Otherwise a pure-PyTorch reference is used
    that reproduces the same sliding window and padding semantics.
    """

    def __init__(self, d_model: int, n_heads: int, half_window: int = 64) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim**-0.5
        self.half_window = half_window
        self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.gate_proj = nn.Linear(d_model, d_model, bias=False)

    @staticmethod
    def _windowed_attention_fallback(
        q: Tensor,
        k: Tensor,
        v: Tensor,
        scale: float,
        half_window: int,
        valid: Tensor | None = None,
    ) -> Tensor:
        """Reference sliding-window attention matching flash-attn ``window_size``.

        Query ``i`` attends key ``j`` iff ``|pos_i - pos_j| <= half_window`` and
        key ``j`` is valid. ``pos`` is the index among valid tokens, i.e. the
        position flash-attn's varlen path sees after unpadding. Rows of invalid
        queries are zeroed, as ``pad_input`` leaves them for flash-attn. Scores
        and softmax are computed in float32.

        Args:
            q, k, v: [B, N, H, D]
            scale: softmax scale applied to ``q @ k^T``.
            half_window: one-sided window size.
            valid: optional [B, N] bool mask of real (non-padding) tokens.

        Returns:
            [B, N, H, D] in ``q.dtype``.
        """
        B, N = q.shape[:2]
        if valid is None:
            pos = torch.arange(N, device=q.device).expand(B, N)
        else:
            pos = valid.long().cumsum(dim=-1) - 1
        allowed = (pos[:, :, None] - pos[:, None, :]).abs() <= half_window
        if valid is not None:
            allowed = allowed & valid[:, None, :]  # never attend to padding keys
        scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
        scores = scores.masked_fill(~allowed[:, None], float("-inf"))
        probs = scores.softmax(dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", probs, v.float())
        if valid is not None:
            # Padded rows may be fully masked (NaN after softmax); masked_fill
            # overwrites them with zeros.
            out = out.masked_fill(~valid[:, :, None, None], 0.0)
        return out.to(q.dtype)

    def forward(self, x: Tensor, attention_params: tuple) -> Tensor:
        B, N = x.shape[:2]
        cos, sin = attention_params[0], attention_params[1]
        x_input = x
        qkv = self.Wqkv(x)
        qkv = qkv.view(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 1, 3, 4)
        q, k, v = qkv.unbind(0)
        q, k = qk_norm(q), qk_norm(k)
        q = apply_rotary_emb_3d(q, cos, sin)
        k = apply_rotary_emb_3d(k, cos, sin)
        input_dtype = q.dtype
        if q.dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
        has_padding_info = len(attention_params) > 2
        # flash-attn kernels only run on CUDA tensors.
        use_flash = FLASH_ATTN_AVAILABLE and q.is_cuda
        if has_padding_info and use_flash:
            indices, cu_seqlens, max_seqlen = (
                attention_params[2],
                attention_params[3],
                attention_params[4],
            )
            q_unpad = index_first_axis(  # type: ignore[misc]
                q.reshape(-1, self.n_heads, self.head_dim), indices
            )
            k_unpad = index_first_axis(  # type: ignore[misc]
                k.reshape(-1, self.n_heads, self.head_dim), indices
            )
            v_unpad = index_first_axis(  # type: ignore[misc]
                v.reshape(-1, self.n_heads, self.head_dim), indices
            )
            out_unpad = flash_attn_varlen_func(  # type: ignore[misc]
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                softmax_scale=self.scale,
                window_size=(self.half_window, self.half_window),
            )
            out = pad_input(out_unpad, indices, B, N)  # type: ignore[misc]
        elif use_flash:
            out = flash_attn_func(  # type: ignore[misc]
                q,
                k,
                v,
                softmax_scale=self.scale,
                window_size=(self.half_window, self.half_window),
            )
        else:
            # Reference path with the same window and padding semantics as the
            # flash-attn branches above.
            valid = None
            if has_padding_info:
                valid = torch.zeros(B * N, dtype=torch.bool, device=q.device)
                valid[attention_params[2]] = True
                valid = valid.view(B, N)
            out = self._windowed_attention_fallback(
                q, k, v, self.scale, self.half_window, valid
            )
        out = out.to(input_dtype).reshape(B, N, -1)  # type: ignore[union-attr]
        out = out * torch.sigmoid(self.gate_proj(x_input))
        return self.out_proj(out)
| # =========================================================================== | |
| # SWAAtomBlock, SWAAtomTransformer | |
| # =========================================================================== | |
| def _rms_adaln_raw(x: Tensor, scale: Tensor, shift: Tensor) -> Tensor: | |
| return F.rms_norm(x, (x.shape[-1],)) * (1 + scale) + shift | |
| def _gated_residual_raw(x: Tensor, gate: Tensor, y: Tensor) -> Tensor: | |
| return x + gate * y | |
class SWAAtomBlock(nn.Module):
    """adaLN-Zero block: SWA attention branch followed by a SwiGLU FFN branch.

    ``adaln_modulation`` = Sequential(SiLU(), Linear) with the Linear
    zero-initialised and bias-free, giving state-dict keys such as
    ``adaln_modulation.1.weight``. It maps the conditioning ``c_l`` to six
    chunks: shift/scale/gate for attention, then shift/scale/gate for the FFN.
    Each branch RMS-normalizes and modulates its input, and its output is added
    back through its gate. With the zero init every gate starts at 0. When
    ``use_compile_fusions`` is set, the two elementwise helpers are wrapped in
    ``torch.compile``.
    """

    def __init__(
        self,
        d_atom: int,
        n_heads: int,
        half_window: int = 64,
        expansion_ratio: int = 2,
        use_compile_fusions: bool = False,
    ) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(d_atom, elementwise_affine=False)
        self.ffn_norm = nn.RMSNorm(d_atom, elementwise_affine=False)
        modulation = nn.Linear(d_atom, 6 * d_atom, bias=False)
        nn.init.zeros_(modulation.weight)
        self.adaln_modulation = nn.Sequential(nn.SiLU(), modulation)
        self.attn = SWA3DRoPEAttention(d_atom, n_heads, half_window=half_window)
        self.ffn = SwiGLUFFN(d_atom, expansion_ratio)
        if use_compile_fusions:
            self._rms_adaln = torch.compile(_rms_adaln_raw)
            self._gated_residual = torch.compile(_gated_residual_raw)
        else:
            self._rms_adaln = _rms_adaln_raw
            self._gated_residual = _gated_residual_raw

    def forward(self, x: Tensor, c_l: Tensor, attention_params: tuple) -> Tensor:
        modulation = self.adaln_modulation(c_l)
        if modulation.dim() == 2:
            # Conditioning without a token axis: add one so it broadcasts over tokens.
            modulation = modulation.unsqueeze(1)
        (
            shift_attn,
            scale_attn,
            gate_attn,
            shift_ffn,
            scale_ffn,
            gate_ffn,
        ) = modulation.chunk(6, dim=-1)
        attn_out = self.attn(
            self._rms_adaln(x, scale_attn, shift_attn), attention_params
        )
        x = self._gated_residual(x, gate_attn, attn_out)
        ffn_out = self.ffn(self._rms_adaln(x, scale_ffn, shift_ffn))
        return self._gated_residual(x, gate_ffn, ffn_out)
class SWAAtomTransformer(nn.Module):
    """Stack of ``n_blocks`` SWAAtomBlocks sharing one conditioning and RoPE setup.

    ``swa_window_size`` is the full window; each block receives
    ``half_window = swa_window_size // 2``. The RoPE hyper-parameters are
    stored so that ``_build_3d_rope`` can build cos/sin tables for this
    transformer's head dimension.
    """

    def __init__(
        self,
        d_atom: int = 128,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.swa_window_size = swa_window_size
        self.head_dim = d_atom // n_heads
        self.spatial_rope_base_frequency = spatial_rope_base_frequency
        self.n_spatial_rope_pairs_per_axis = n_spatial_rope_pairs_per_axis
        self.n_uid_rope_pairs = n_uid_rope_pairs
        self.uid_rope_base_frequency = uid_rope_base_frequency
        half_window = swa_window_size // 2
        self.blocks = nn.ModuleList(
            SWAAtomBlock(
                d_atom=d_atom,
                n_heads=n_heads,
                half_window=half_window,
                expansion_ratio=expansion_ratio,
            )
            for _ in range(n_blocks)
        )

    def _build_3d_rope(
        self, ref_pos: Tensor, ref_space_uid: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` RoPE tables built with this transformer's settings."""
        return build_3d_rope(
            ref_pos=ref_pos,
            ref_space_uid=ref_space_uid,
            head_dim=self.head_dim,
            n_spatial_per_axis=self.n_spatial_rope_pairs_per_axis,
            n_uid_pairs=self.n_uid_rope_pairs,
            spatial_base_freq=self.spatial_rope_base_frequency,
            uid_base_freq=self.uid_rope_base_frequency,
        )

    def forward(
        self,
        q_l: Tensor,
        c_l: Tensor,
        attention_params: tuple,
        return_intermediates: bool = False,
    ) -> Tensor | tuple[Tensor, list[Tensor]]:
        """Run every block in order.

        Returns the final activations, or ``(final, per_block_outputs)`` when
        ``return_intermediates`` is set.
        """
        per_block: list[Tensor] = []
        hidden = q_l
        for block in self.blocks:
            hidden = block(hidden, c_l, attention_params)
            if return_intermediates:
                per_block.append(hidden)
        return (hidden, per_block) if return_intermediates else hidden
| # =========================================================================== | |
| # ESMFold2AtomEncoder (for both inputs_embedder and diffusion_module) | |
| # =========================================================================== | |
class ESMFold2AtomEncoder(nn.Module):
    """SWA atom encoder with atom_linear, atom_norm, atom_to_token_linear, [coords_linear], atom_transformer.

    Builds per-atom conditioning from reference features, runs a sliding-window
    atom transformer, and mean-pools the result onto tokens.

    Args:
        d_atom: atom hidden dim
        d_token: token dim for atom_to_token aggregation
        n_blocks, n_heads, swa_window_size, expansion_ratio: transformer params
        structure_prediction: if True, creates coords_linear and projects to
            the full d_token (otherwise d_token // 2)
        spatial_rope_base_frequency, n_spatial_rope_pairs_per_axis,
        n_uid_rope_pairs, uid_rope_base_frequency: 3D RoPE config
    """

    def __init__(
        self,
        d_atom: int = 128,
        d_token: int = 768,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        structure_prediction: bool = True,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.d_atom = d_atom
        self.d_token = d_token
        self.structure_prediction = structure_prediction
        self.atom_linear = nn.Linear(ATOM_FEATURE_DIM, d_atom, bias=False)
        self.atom_norm = nn.LayerNorm(d_atom)
        if structure_prediction:
            # Projects concatenated [r_l, pred_r1] (3 + 3 features) into q.
            self.coords_linear = nn.Linear(6, d_atom, bias=False)
        self.atom_transformer = SWAAtomTransformer(
            d_atom=d_atom,
            n_blocks=n_blocks,
            n_heads=n_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=expansion_ratio,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        # Output aggregation: d_token for structure prediction, d_token//2 for inputs
        out_dim = d_token if structure_prediction else d_token // 2
        self.atom_to_token_linear = nn.Linear(d_atom, out_dim, bias=False)

    def forward(
        self,
        ref_pos: Tensor,
        atom_attention_mask: Tensor,
        ref_space_uid: Tensor,
        ref_charge: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        atom_to_token: Tensor,
        r_l: Tensor | None = None,
        pred_r1: Tensor | None = None,
        s_i: Tensor | None = None,
        z_ij: Tensor | None = None,
        num_diffusion_samples: int = 1,
        return_intermediates: bool = False,
        inference_cache: dict | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, tuple, list[Tensor]]:
        """Encode atoms and pool them to tokens.

        Returns ``(a, q, c, attention_params, intermediates)``:

        - ``a``: token features ``[B*S, n_tokens, out_dim]``, where ``S`` is
          ``num_diffusion_samples``.
        - ``q``: final atom activations.
        - ``c``: atom conditioning, expanded to ``B*S`` like ``q``.
        - ``attention_params``: ``(cos, sin, indices, cu_seqlens, max_seqlen)``
          for ``SWA3DRoPEAttention``.
        - ``intermediates``: per-block activations when
          ``return_intermediates``, else ``[]``.

        ``inference_cache`` caches step-invariant tensors (c_base, 3D RoPE,
        attention indices, n_tokens) across diffusion steps under its
        ``"atomencoder"`` key. The cache is filled on the first call and reused
        unchanged afterwards, so ``num_diffusion_samples`` must stay fixed for
        the lifetime of a cache. ``s_i`` and ``z_ij`` are accepted but unused
        here.
        """
        B, N = ref_pos.shape[:2]
        layer_cache = None
        if inference_cache is not None:
            layer_cache = inference_cache.setdefault("atomencoder", {})
        if layer_cache is None or len(layer_cache) == 0:
            atom_feats = torch.cat(
                [
                    ref_pos,
                    ref_charge.unsqueeze(-1),
                    atom_attention_mask.unsqueeze(-1),
                    ref_element,
                    ref_atom_name_chars.reshape(B, N, MAX_CHARS * CHAR_VOCAB_SIZE),
                ],
                dim=-1,
            )
            c_base = self.atom_norm(self.atom_linear(atom_feats))
            cos, sin = self.atom_transformer._build_3d_rope(ref_pos, ref_space_uid)
            cos = cos.repeat_interleave(num_diffusion_samples, 0)
            sin = sin.repeat_interleave(num_diffusion_samples, 0)
            mask_exp = atom_attention_mask.repeat_interleave(num_diffusion_samples, 0)
            # flash-attn varlen metadata over the sample-expanded batch.
            seqlens = mask_exp.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(mask_exp.flatten(), as_tuple=False).flatten()
            max_seqlen = int(seqlens.max().item())
            cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
            attention_params = (cos, sin, indices, cu_seqlens, max_seqlen)
            n_tokens = int(atom_to_token.max().item()) + 1
            if layer_cache is not None:
                layer_cache["c_base"] = c_base
                layer_cache["attention_params"] = attention_params
                layer_cache["mask_exp"] = mask_exp
                layer_cache["n_tokens"] = n_tokens
                layer_cache["atom_to_token_exp"] = atom_to_token.repeat_interleave(
                    num_diffusion_samples, 0
                )
        else:
            c_base = layer_cache["c_base"]
            attention_params = layer_cache["attention_params"]
            mask_exp = layer_cache["mask_exp"]
            n_tokens = layer_cache["n_tokens"]
        # Expand the conditioning to the sample batch so it lines up with the
        # RoPE tables, varlen indices, mask_exp and atom_to_token_exp, which are
        # always expanded by num_diffusion_samples. (Previously only the
        # coordinate-conditioned path expanded q/c.)
        if num_diffusion_samples == 1:
            c = c_base
        else:
            c = c_base.repeat_interleave(num_diffusion_samples, 0)
        q = c
        if self.structure_prediction and r_l is not None:
            if pred_r1 is None:
                pred_r1 = torch.zeros_like(r_l)
            r_input = torch.cat([r_l, pred_r1], dim=-1)
            q = q + self.coords_linear(r_input)
        result = self.atom_transformer(
            q_l=q,
            c_l=c,
            attention_params=attention_params,
            return_intermediates=return_intermediates,
        )
        if return_intermediates:
            q, intermediates = result
        else:
            q = result
            intermediates = []
        q_to_a = F.relu(self.atom_to_token_linear(q))
        if layer_cache is not None and "atom_to_token_exp" in layer_cache:
            atom_to_token_exp = layer_cache["atom_to_token_exp"]
        else:
            atom_to_token_exp = atom_to_token.repeat_interleave(
                num_diffusion_samples, 0
            )
        a = scatter_atom_to_token(
            q_to_a, atom_to_token_exp, n_tokens, atom_mask=mask_exp.bool()
        )
        return a, q, c, attention_params, intermediates
| # =========================================================================== | |
| # ESMFold2AtomDecoder | |
| # =========================================================================== | |
class ESMFold2AtomDecoder(nn.Module):
    """SWA atom decoder: token_to_atom_linear, atom_transformer, norm, output_linear.

    Token features are projected to the atom width, broadcast onto their
    atoms, and added to the atom activations. The result goes through a
    sliding-window atom transformer, and each atom is then mapped to an
    ``XYZ_DIMS``-wide update.
    """

    def __init__(
        self,
        d_atom: int = 128,
        d_token: int = 768,
        n_blocks: int = 3,
        n_heads: int = 4,
        swa_window_size: int = 128,
        expansion_ratio: int = 2,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.token_to_atom_linear = nn.Linear(d_token, d_atom, bias=False)
        self.atom_transformer = SWAAtomTransformer(
            d_atom=d_atom,
            n_blocks=n_blocks,
            n_heads=n_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=expansion_ratio,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        self.norm = nn.LayerNorm(d_atom)
        self.output_linear = nn.Linear(d_atom, XYZ_DIMS, bias=False)

    def forward(
        self,
        a_i: Tensor,
        q_l: Tensor,
        c_l: Tensor,
        p_lm: tuple,
        atom_to_token: Tensor,
        atom_attention_mask: Tensor,
        num_diffusion_samples: int = 1,
        return_intermediates: bool = False,
    ) -> tuple[Tensor, list[Tensor]]:
        """Return ``(r_update, intermediates)``.

        ``p_lm`` is the attention-params tuple forwarded to the atom
        transformer. ``atom_attention_mask`` is accepted but not read here.
        ``intermediates`` is the list of per-block activations, or ``[]``.
        """
        expanded_idx = atom_to_token.repeat_interleave(num_diffusion_samples, 0)
        token_ctx = gather_token_to_atom(self.token_to_atom_linear(a_i), expanded_idx)
        transformed = self.atom_transformer(
            q_l=q_l + token_ctx,
            c_l=c_l,
            attention_params=p_lm,
            return_intermediates=return_intermediates,
        )
        if return_intermediates:
            hidden, intermediates = transformed
        else:
            hidden, intermediates = transformed, []
        return self.output_linear(self.norm(hidden)), intermediates
| # =========================================================================== | |
| # AttentionPairBias (DiffusionTransformer attention block) | |
| # =========================================================================== | |
class AttentionPairBias(nn.Module):
    """Gated multi-head self-attention whose logits carry a pair bias.

    Queries, keys and values are all projected from the normalised input
    (``adaln(a, s)`` when conditioned, ``pre_norm(a)`` otherwise). The bias is
    derived from ``z``:
      * a 4-D ``z`` goes through ``pair_norm`` + ``pair_bias_proj`` to give one
        bias per head;
      * any other ``z`` is broadcast to every head via ``z.unsqueeze(-1)``.

    The attention context is gated by ``sigmoid(g_proj(x))`` and projected by
    ``out_proj``. When ``s`` is given, the result is also scaled by
    ``sigmoid(out_gate(s))``.

    ``forward`` picks one of three execution paths:
      * a fused pair-bias kernel plus ``F.scaled_dot_product_attention``;
      * a ``_cue_attn_pair_bias`` kernel;
      * a plain PyTorch fallback.
    """

    def __init__(
        self,
        d_model: int,
        d_pair: int,
        num_heads: int,
        d_cond: int | None = None,
        use_conditioning: bool = True,
    ) -> None:
        """Build the projections.

        Args:
            d_model: Model width. ``forward`` views projections as
                ``(num_heads, d_model // num_heads)``, so it must be divisible
                by ``num_heads``.
            d_pair: Channels of ``z``. When 0, ``pair_norm``/``pair_bias_proj``
                are not created, so the 4-D ``z`` paths are unavailable.
            num_heads: Number of attention heads.
            d_cond: Width of the conditioning ``s``. Falls back to ``d_model``
                when ``None`` (or 0, because of the ``or``).
            use_conditioning: If True, build ``adaln`` and ``out_gate``.
                Otherwise build a plain ``pre_norm`` LayerNorm.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim**-0.5
        d_cond = d_cond or d_model
        if use_conditioning:
            self.adaln = AdaptiveLayerNorm(d_model, d_cond, eps=1e-5)
            self.out_gate = nn.Linear(d_cond, d_model, bias=True)
            # out_gate init: weight=0, bias=-2, so sigmoid(out_gate(s)) starts
            # at ~0.12 for any s (a damped residual at initialisation).
            nn.init.zeros_(self.out_gate.weight)
            nn.init.constant_(self.out_gate.bias, -2.0)
        else:
            self.pre_norm = nn.LayerNorm(d_model, eps=1e-5)
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.g_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        if d_pair > 0:
            self.pair_norm = nn.LayerNorm(d_pair, eps=1e-5)
            self.pair_bias_proj = nn.Linear(d_pair, num_heads, bias=False)
        # Selected kernel backend; consulted by the module-level
        # _fused_active/_cueq_active helpers (defined elsewhere in this file).
        self._kernel_backend: str | None = None

    def set_kernel_backend(self, backend: str | None) -> None:
        """Select the attention kernel backend.

        Raises:
            ValueError: if ``backend`` is not in ``_VALID_BACKENDS``.
        """
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        self._kernel_backend = backend

    def _is_zero_beta(self, beta: Tensor | float) -> bool:
        """Return True if ``beta`` is the number 0 or an all-zero tensor.

        For a CUDA tensor, the ``bool(...)`` conversion forces a host sync.
        """
        if isinstance(beta, (int, float)):
            return beta == 0.0
        return bool((beta == 0).all())

    def _can_use_fused_pair_bias(
        self, z: Tensor, n_queries: int, beta: Tensor | float
    ) -> bool:
        """Whether the fused pair-bias + SDPA path applies.

        ``n_queries`` is accepted for signature parity with
        ``_can_use_cueq_pair_bias`` but is not used.
        """
        return (
            _fused_active(self, z)
            and z.dim() == 4
            and self._is_zero_beta(beta)
            and hasattr(self, "pair_bias_proj")
            and hasattr(self, "pair_norm")
        )

    def _can_use_cueq_pair_bias(
        self, z: Tensor, n_queries: int, beta: Tensor | float
    ) -> bool:
        """Whether the ``_cue_attn_pair_bias`` path applies (only for > 750 queries)."""
        return (
            _cueq_active(self)
            and n_queries > 750
            and z.dim() == 4
            and self._is_zero_beta(beta)
            and hasattr(self, "pair_bias_proj")
        )

    def forward(
        self,
        a: Tensor,
        s: Tensor | None,
        z: Tensor,
        beta: Tensor | float = 0.0,
        attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
    ) -> Tensor:
        """Apply pair-biased gated self-attention to ``a``.

        Args:
            a: Input unpacked as ``(bsz, n_queries, d_model)``.
            s: Conditioning for ``adaln`` and ``out_gate``. ``None`` selects
                ``pre_norm``, which only exists when the module was built with
                ``use_conditioning=False``.
            z: Pair representation. A 4-D ``z`` (presumably
                ``(batch, n_queries, n_keys, d_pair)`` — TODO confirm) is
                normalised and projected to one bias per head. Any other ``z``
                is added to every head as ``z.unsqueeze(-1)``. If its batch
                differs from ``bsz`` it is repeat-interleaved
                ``num_diffusion_samples`` times.
            beta: Only used to decide kernel eligibility; both kernel paths
                require it to be zero. NOTE(review): the fallback path never
                adds ``beta`` to the logits, so a nonzero value is silently
                ignored. Confirm this is intended.
            attention_mask: Key mask indexed ``[:, None, :, None]`` against
                ``(bsz, Q, K, H)`` logits, i.e. one entry per key. Expanded
                like ``z`` when its batch differs from ``bsz``.
            num_diffusion_samples: Replication factor for the expansions above.

        Returns:
            Tensor of shape ``(bsz, n_queries, d_model)``.
        """
        bsz, n_queries, d_model = a.shape
        if s is not None:
            x = self.adaln(a, s)
        else:
            x = self.pre_norm(a)
        # Self-attention: keys come from the same x, so n_keys == n_queries.
        n_keys = x.shape[1]
        q = self.q_proj(x).view(bsz, n_queries, self.num_heads, self.head_dim)
        kv = self.kv_proj(x)
        k, v = kv.chunk(2, dim=-1)
        k = k.view(bsz, n_keys, self.num_heads, self.head_dim)
        v = v.view(bsz, n_keys, self.num_heads, self.head_dim)
        # z (and the mask) may still be at the per-input batch size; replicate
        # them once per diffusion sample to match a's batch.
        if z.dim() == 4 and z.shape[0] != bsz and num_diffusion_samples > 1:
            z = z.repeat_interleave(num_diffusion_samples, dim=0)
        if (
            attention_mask is not None
            and attention_mask.shape[0] != bsz
            and num_diffusion_samples > 1
        ):
            attention_mask = attention_mask.repeat_interleave(
                num_diffusion_samples, dim=0
            )
        # Path 1: fused pair-bias kernel + SDPA (returns early).
        if self._can_use_fused_pair_bias(z, n_queries, beta):
            # The all-ones default is sized by n_queries, which equals n_keys here.
            kernel_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool)
            )
            pair_norm_w = self.pair_norm.weight
            pair_norm_b = (
                self.pair_norm.bias
                if self.pair_norm.bias is not None
                else torch.zeros_like(pair_norm_w)
            )
            # The fused kernel is fed bfloat16 pair activations.
            z_bf = z if z.dtype == torch.bfloat16 else z.to(torch.bfloat16)
            bias = _fused_pair_bias(  # type: ignore[misc]
                z_bf,
                kernel_mask,
                self.pair_bias_proj.weight,
                num_heads=self.num_heads,
                pair_norm_w=pair_norm_w,
                pair_norm_b=pair_norm_b,
            )  # (B, H, Q, K)
            q_bhqd = q.transpose(1, 2)
            k_bhqd = k.transpose(1, 2)
            v_bhqd = v.transpose(1, 2)
            # SDPA's default scale is 1/sqrt(head_dim), matching self.scale
            # used by the fallback path.
            attn_out = F.scaled_dot_product_attention(
                q_bhqd, k_bhqd, v_bhqd, attn_mask=bias.to(q_bhqd.dtype)
            )
            g = torch.sigmoid(self.g_proj(x)).view(
                bsz, n_queries, self.num_heads, self.head_dim
            )
            ctx = g * attn_out.transpose(1, 2)
            out = self.out_proj(ctx.reshape(bsz, n_queries, d_model))
            if s is not None:
                out = torch.sigmoid(self.out_gate(s)) * out
            return out
        # Path 2: _cue_attn_pair_bias kernel. It receives the g/out projection
        # weights and returns the projected output; the out_gate is applied below.
        if self._can_use_cueq_pair_bias(z, n_queries, beta):
            kernel_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(bsz, n_queries, device=a.device, dtype=torch.bool)
            )
            out, _ = _cue_attn_pair_bias(  # type: ignore[misc]
                s=x,
                q=q.transpose(1, 2),
                k=k.transpose(1, 2),
                v=v.transpose(1, 2),
                z=z,
                mask=kernel_mask,
                num_heads=self.num_heads,
                w_proj_z=self.pair_bias_proj.weight,
                w_proj_g=self.g_proj.weight,
                w_proj_o=self.out_proj.weight,
                w_ln_z=self.pair_norm.weight,
                b_ln_z=self.pair_norm.bias,
                return_z_proj=False,
                is_cached_z_proj=False,
            )
        else:
            # Path 3: standard attention with pair bias; logits are laid out as
            # (..., query i, key j, head h).
            g = torch.sigmoid(self.g_proj(x)).view(
                bsz, n_queries, self.num_heads, self.head_dim
            )
            logits = (
                torch.einsum("... i h d, ... j h d -> ... i j h", q, k) * self.scale
            )
            if z.dim() == 4:
                pair_bias = self.pair_bias_proj(self.pair_norm(z))
            else:
                # Non-4-D z is used directly as a bias shared by all heads.
                pair_bias = z.unsqueeze(-1)
            logits = logits + pair_bias.to(dtype=logits.dtype)
            if attention_mask is not None:
                # Masked keys get the most negative finite value of the logit dtype.
                min_val = torch.finfo(logits.dtype).min
                mask_bias = torch.where(
                    attention_mask.bool()[:, None, :, None], 0.0, min_val
                )
                logits = logits + mask_bias.to(dtype=logits.dtype)
            # Softmax over the key axis (dim -2 in the i j h layout).
            attn = torch.softmax(logits, dim=-2).to(dtype=v.dtype)
            ctx = torch.einsum("... i j h, ... j h d -> ... i h d", attn, v)
            ctx = g * ctx
            out = self.out_proj(ctx.reshape(bsz, n_queries, d_model))
        if s is not None:
            out = torch.sigmoid(self.out_gate(s)) * out
        return out
| # =========================================================================== | |
| # ConditionedTransitionBlock | |
| # =========================================================================== | |
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU feed-forward block, optionally conditioned via adaptive layer norm.

    With ``use_conditioning=True``:
      * the input is normalised by ``AdaptiveLayerNorm`` driven by ``s``;
      * the output is scaled by ``sigmoid(output_gate(s))``. The gate is
        initialised to weight 0, bias -2, i.e. ~0.12.

    Without conditioning, a plain ``LayerNorm`` is used and no gate is applied.
    The hidden width is ``transition_multiplier * d_model``.
    """

    def __init__(
        self,
        d_model: int,
        d_cond: int | None = None,
        transition_multiplier: int = 2,
        use_conditioning: bool = True,
    ) -> None:
        super().__init__()
        cond_dim = d_cond or d_model
        hidden_dim = transition_multiplier * d_model
        if not use_conditioning:
            self.pre_norm = nn.LayerNorm(d_model, eps=1e-5)
        else:
            self.adaln = AdaptiveLayerNorm(d_model, cond_dim, eps=1e-5)
            self.output_gate = nn.Linear(cond_dim, d_model, bias=True)
            nn.init.zeros_(self.output_gate.weight)
            nn.init.constant_(self.output_gate.bias, -2.0)
        # One projection produces both SwiGLU halves (gate and value).
        self.lin_swish = nn.Linear(d_model, 2 * hidden_dim, bias=False)
        self.lin_out = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, a: Tensor, s: Tensor | None) -> Tensor:
        """Return the transition update for ``a``; the caller adds the residual.

        ``s=None`` selects ``pre_norm``, which exists only when the block was
        built with ``use_conditioning=False``.
        """
        conditioned = s is not None
        normed = self.adaln(a, s) if conditioned else self.pre_norm(a)
        gate_in, value_in = self.lin_swish(normed).chunk(2, dim=-1)
        out = self.lin_out(F.silu(gate_in) * value_in)
        if conditioned:
            out = out * torch.sigmoid(self.output_gate(s))
        return out
| # =========================================================================== | |
| # DiffusionTransformer (token transformer) | |
| # =========================================================================== | |
class DiffusionTransformer(nn.Module):
    """Residual stack of pair-biased attention and conditioned transitions.

    For each block ``i`` the stream is updated as
    ``x = x + attn_blocks[i](x, s, z, beta, ...)`` followed by
    ``x = x + transition_blocks[i](x, s)``.
    """

    def __init__(
        self,
        d_model: int,
        d_pair: int,
        num_heads: int,
        num_blocks: int,
        d_cond: int | None = None,
        transition_multiplier: int = 2,
        use_conditioning: bool = True,
    ) -> None:
        super().__init__()
        cond_width = d_cond or d_model
        # All attention blocks are built first, then all transitions, to keep
        # parameter-initialisation order stable.
        self.attn_blocks = nn.ModuleList(
            AttentionPairBias(
                d_model=d_model,
                d_pair=d_pair,
                num_heads=num_heads,
                d_cond=cond_width,
                use_conditioning=use_conditioning,
            )
            for _ in range(num_blocks)
        )
        self.transition_blocks = nn.ModuleList(
            ConditionedTransitionBlock(
                d_model=d_model,
                d_cond=cond_width,
                transition_multiplier=transition_multiplier,
                use_conditioning=use_conditioning,
            )
            for _ in range(num_blocks)
        )

    def set_kernel_backend(self, backend: str | None) -> None:
        """Propagate ``backend`` to every attention block."""
        for block in self.attn_blocks:
            cast(AttentionPairBias, block).set_kernel_backend(backend)

    def forward(
        self,
        a: Tensor,
        s: Tensor | None,
        z: Tensor,
        beta: Tensor | float = 0.0,
        attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
        return_intermediates: bool = False,
    ) -> tuple[Tensor, list[Tensor]]:
        """Run all blocks.

        Returns:
            ``(x, intermediates)``. ``intermediates`` holds the stream after
            each block when ``return_intermediates`` is set, and is empty
            otherwise.
        """
        attn_kwargs = {
            "attention_mask": attention_mask,
            "num_diffusion_samples": num_diffusion_samples,
        }
        hidden = a
        per_block: list[Tensor] = []
        for attention, transition in zip(self.attn_blocks, self.transition_blocks):
            hidden = hidden + attention(hidden, s, z, beta, **attn_kwargs)
            hidden = hidden + transition(hidden, s)
            if return_intermediates:
                per_block.append(hidden)
        return hidden, per_block
| # =========================================================================== | |
| # DiffusionConditioning | |
| # =========================================================================== | |
class DiffusionConditioning(nn.Module):
    """Build the pair (``z``) and single (``s``) conditioning for one denoising step.

    The pair path does not depend on the noise level. It concatenates
    ``z_trunk`` with the relative-position encoding, then applies
    LayerNorm -> Linear -> two residual ``TransitionLayer`` blocks. The result
    can be cached in ``inference_cache["z"]``.

    The single path projects ``s_inputs`` to ``c_s`` and adds a Fourier
    embedding of ``0.25 * log(t_hat / sigma)``. It then applies two residual
    ``TransitionLayer`` blocks.
    """

    def __init__(
        self,
        c_z: int = 256,
        c_s: int = 768,
        c_s_inputs: int = 451,
        sigma_data: float = 16.0,
        fourier_dim: int = 256,
        transition_multiplier: int = 2,
        layer_norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()
        self.sigma_data = float(sigma_data)
        self.c_z = c_z
        self.c_s = c_s
        self.c_s_inputs = c_s_inputs

        def transition_pair(width: int) -> nn.ModuleList:
            # Two residual transitions of the given width.
            return nn.ModuleList(
                TransitionLayer(width, n=transition_multiplier, eps=layer_norm_eps)
                for _ in range(2)
            )

        # Pair path (module creation order is kept stable for init/state-dict).
        self.z_input_norm = nn.LayerNorm(2 * c_z, eps=layer_norm_eps)
        self.z_proj = nn.Linear(2 * c_z, c_z, bias=False)
        self.z_transitions = transition_pair(c_z)
        # Single path.
        self.s_input_norm = nn.LayerNorm(c_s_inputs, eps=layer_norm_eps)
        self.s_proj = nn.Linear(c_s_inputs, c_s, bias=False)
        self.fourier = FourierEmbedding(fourier_dim)
        self.noise_norm = nn.LayerNorm(fourier_dim, eps=layer_norm_eps)
        self.noise_proj = nn.Linear(fourier_dim, c_s, bias=False)
        self.s_transitions = transition_pair(c_s)

    def forward(
        self,
        t_hat: Tensor,
        s_inputs: Tensor,
        s_trunk: Tensor | None,
        z_trunk: Tensor,
        relative_position_encoding: Tensor,
        sigma_data: float | None = None,
        num_diffusion_samples: int = 1,
        inference_cache: dict[str, Tensor] | None = None,
    ) -> tuple[Tensor, Tensor]:
        """Return ``(s, z)``.

        ``s_trunk`` is accepted for interface compatibility but not used.
        ``z`` keeps the per-input batch of ``z_trunk``. ``s`` has
        ``z_trunk.shape[0] * num_diffusion_samples`` rows.
        """
        sigma = self.sigma_data if sigma_data is None else float(sigma_data)
        target_batch = z_trunk.shape[0] * num_diffusion_samples
        z = self._conditioned_pair(z_trunk, relative_position_encoding, inference_cache)
        s = self._conditioned_single(
            t_hat, s_inputs, sigma, target_batch, num_diffusion_samples
        )
        return s, z

    def _conditioned_pair(
        self,
        z_trunk: Tensor,
        relative_position_encoding: Tensor,
        inference_cache: dict[str, Tensor] | None,
    ) -> Tensor:
        """Pair conditioning, reused from ``inference_cache`` when present."""
        if inference_cache is not None and "z" in inference_cache:
            return inference_cache["z"]
        pair = torch.cat(
            [
                z_trunk.to(dtype=torch.float32),
                relative_position_encoding.to(dtype=torch.float32),
            ],
            dim=-1,
        )
        pair = self.z_proj(self.z_input_norm(pair))
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            for transition in self.z_transitions:
                pair = pair + transition(pair)
        if inference_cache is not None:
            inference_cache["z"] = pair
        return pair

    def _conditioned_single(
        self,
        t_hat: Tensor,
        s_inputs: Tensor,
        sigma: float,
        target_batch: int,
        num_diffusion_samples: int,
    ) -> Tensor:
        """Single conditioning with the noise-level embedding added."""
        if s_inputs.shape[0] != target_batch:
            s_inputs = s_inputs.repeat_interleave(num_diffusion_samples, 0)
        single = self.s_proj(self.s_input_norm(s_inputs.to(dtype=torch.float32)))
        noise_level = torch.as_tensor(
            t_hat, dtype=torch.float32, device=single.device
        ).reshape(-1)
        if noise_level.numel() == 1:
            noise_level = noise_level.expand(target_batch)
        elif noise_level.shape[0] != target_batch:
            noise_level = noise_level.repeat_interleave(num_diffusion_samples, 0)
        # Clamp keeps log() finite when t_hat is 0.
        log_noise = 0.25 * torch.log((noise_level / sigma).clamp(min=1e-20))
        noise_embedding = self.noise_proj(self.noise_norm(self.fourier(log_noise)))
        single = single + noise_embedding.unsqueeze(1)
        for transition in self.s_transitions:
            single = single + transition(single)
        return single
| # =========================================================================== | |
| # DiffusionModule | |
| # =========================================================================== | |
class DiffusionModule(nn.Module):
    """Diffusion denoising module for structure prediction.

    One call denoises ``x_noisy`` at noise level ``t_hat`` in these steps:
      1. conditioning;
      2. atom encoder on the normalised noisy coordinates;
      3. token transformer with pair bias;
      4. atom decoder.

    The decoder update is then combined as
    ``x_denoised = sigma^2/(sigma^2+t^2) * x_noisy
    + sigma*t/sqrt(sigma^2+t^2) * r_update``.
    """

    def __init__(
        self,
        c_atom: int = 128,
        c_token: int = 768,
        c_z: int = 256,
        c_s_inputs: int = 451,
        sigma_data: float = 16.0,
        fourier_dim: int = 256,
        atom_num_blocks: int = 3,
        atom_num_heads: int = 4,
        token_num_blocks: int = 12,
        token_num_heads: int = 16,
        transition_multiplier: int = 2,
        swa_window_size: int = 128,
        spatial_rope_base_frequency: float = 20.0,
        n_spatial_rope_pairs_per_axis: int = 2,
        n_uid_rope_pairs: int = 10,
        uid_rope_base_frequency: float = 10000.0,
    ) -> None:
        super().__init__()
        self.sigma_data = float(sigma_data)
        self.conditioning = DiffusionConditioning(
            c_z=c_z,
            c_s=c_token,  # conditioning s output is c_token
            c_s_inputs=c_s_inputs,
            sigma_data=sigma_data,
            fourier_dim=fourier_dim,
            transition_multiplier=transition_multiplier,
        )
        # Atom encoder (structure_prediction=True, with coords_linear)
        self.atom_encoder = ESMFold2AtomEncoder(
            d_atom=c_atom,
            d_token=c_token,
            n_blocks=atom_num_blocks,
            n_heads=atom_num_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=2,
            structure_prediction=True,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        # Atom decoder
        self.atom_decoder = ESMFold2AtomDecoder(
            d_atom=c_atom,
            d_token=c_token,
            n_blocks=atom_num_blocks,
            n_heads=atom_num_heads,
            swa_window_size=swa_window_size,
            expansion_ratio=2,
            spatial_rope_base_frequency=spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=n_uid_rope_pairs,
            uid_rope_base_frequency=uid_rope_base_frequency,
        )
        # Zero-initialised so the conditioned s contributes nothing at init.
        self.s_to_token = nn.Linear(c_token, c_token, bias=False)
        nn.init.zeros_(self.s_to_token.weight)
        # Token transformer (DiffusionTransformer with pair bias)
        self.token_transformer = DiffusionTransformer(
            d_model=c_token,
            d_pair=c_z,
            num_heads=token_num_heads,
            num_blocks=token_num_blocks,
            d_cond=c_token,
            transition_multiplier=transition_multiplier,
            use_conditioning=True,
        )
        self.s_step_norm = nn.LayerNorm(c_token)
        self.token_norm = nn.LayerNorm(c_token)

    def set_kernel_backend(self, backend: str | None) -> None:
        """Forward the kernel backend choice to the token transformer."""
        self.token_transformer.set_kernel_backend(backend)

    def forward(
        self,
        x_noisy: Tensor,
        t_hat: Tensor,
        ref_pos: Tensor,
        ref_charge: Tensor,
        ref_mask: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        ref_space_uid: Tensor,
        tok_idx: Tensor,
        s_inputs: Tensor,
        s_trunk: Tensor | None,
        z_trunk: Tensor,
        relative_position_encoding: Tensor,
        asym_id: Tensor,
        residue_index: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
        sym_id: Tensor,
        sigma_data: float | None = None,
        token_attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
        return_token_repr: bool = False,
        return_atom_repr: bool = False,
        inference_cache: dict[str, Tensor] | None = None,
    ) -> dict[str, Tensor | None]:
        """Denoise ``x_noisy`` once.

        Args:
            x_noisy: Noisy atom coordinates indexed ``[:, None, None]`` against
                per-row scalars. Its batch (``bsz``) already contains
                ``num_diffusion_samples`` copies per input.
            t_hat: Noise level. Accepted forms:
                * a scalar;
                * one value per row of ``x_noisy``;
                * one value per input, which is repeat-interleaved
                  ``num_diffusion_samples`` times.
            asym_id, residue_index, entity_id, token_index, sym_id: Accepted
                for interface compatibility; not used here.
            sigma_data: Overrides ``self.sigma_data`` when given.
            inference_cache: Shared with the conditioning and the atom encoder
                so that step-independent work is reused across diffusion steps.

        Returns:
            Dict with:
              * ``x_denoised``;
              * ``token_repr`` — token features after ``token_norm``, or
                ``None`` unless requested;
              * ``atom_intermediates`` — encoder + decoder intermediates
                stacked on dim 2, or ``None``.
        """
        bsz = x_noisy.shape[0]
        sigma = self.sigma_data if sigma_data is None else float(sigma_data)
        t = torch.as_tensor(t_hat, dtype=torch.float32, device=x_noisy.device).reshape(
            -1
        )
        if t.numel() == 1:
            t = t.expand(bsz)
        elif t.shape[0] != bsz:
            # One noise level per input: replicate per diffusion sample (as
            # DiffusionConditioning does) so the scalings below broadcast
            # against x_noisy row by row.
            t = t.repeat_interleave(num_diffusion_samples, 0)
        # Step 1: conditioning (pair z is cached across diffusion steps)
        s, z = self.conditioning(
            t_hat=t,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            relative_position_encoding=relative_position_encoding,
            sigma_data=sigma,
            num_diffusion_samples=num_diffusion_samples,
            inference_cache=inference_cache,
        )
        # Step 2: normalize noisy coords to roughly unit variance
        denom = torch.sqrt(t * t + sigma * sigma)
        r_noisy = x_noisy / denom[:, None, None]
        # Step 3: atom encoder
        a, q_skip, c_skip, p_skip, enc_intermediates = self.atom_encoder(
            ref_pos=ref_pos,
            atom_attention_mask=ref_mask,
            ref_space_uid=ref_space_uid,
            ref_charge=ref_charge,
            ref_element=ref_element,
            ref_atom_name_chars=ref_atom_name_chars,
            atom_to_token=tok_idx,
            r_l=r_noisy,
            s_i=s_trunk,
            num_diffusion_samples=num_diffusion_samples,
            return_intermediates=return_atom_repr,
            inference_cache=inference_cache,
        )
        # Step 4: add conditioned s
        a = a + self.s_to_token(self.s_step_norm(s))
        # Step 5: token transformer
        a, _ = self.token_transformer(
            a,
            s,
            z,
            beta=0.0,
            attention_mask=token_attention_mask,
            num_diffusion_samples=num_diffusion_samples,
        )
        # Step 6: token norm
        a = self.token_norm(a)
        # Step 7: atom decoder (reuses the encoder's skip features/attention params)
        r_update, dec_intermediates = self.atom_decoder(
            a_i=a,
            q_l=q_skip,
            c_l=c_skip,
            p_lm=p_skip,
            atom_to_token=tok_idx,
            atom_attention_mask=ref_mask,
            num_diffusion_samples=num_diffusion_samples,
            return_intermediates=return_atom_repr,
        )
        # Step 8: combine skip and network output with noise-dependent weights
        sigma2 = sigma * sigma
        t2 = t * t
        out = (sigma2 / (sigma2 + t2))[:, None, None] * x_noisy
        out = out + ((sigma * t) / torch.sqrt(sigma2 + t2))[:, None, None] * r_update
        # Collect atom intermediates from encoder + decoder
        atom_intermediates: Tensor | None = None
        if return_atom_repr:
            all_ints = enc_intermediates + dec_intermediates
            if all_ints:
                atom_intermediates = torch.stack(all_ints, dim=2)
        return {
            "x_denoised": out,
            "token_repr": a if return_token_repr else None,
            "atom_intermediates": atom_intermediates,
        }
| # =========================================================================== | |
| # DiffusionStructureHead | |
| # =========================================================================== | |
class DiffusionStructureHead(nn.Module):
    """Wrapper around DiffusionModule with diffusion sampling.

    ``__init__`` does two things:
      * builds the denoiser from ``config.structure_head.diffusion_module``
        and ``config.inputs.atom_encoder``;
      * copies the sampling hyperparameters (schedule bounds,
        ``gamma_0``/``gamma_min``, noise and step scales) onto the instance.

    ``sample`` runs the reverse-diffusion loop.
    """

    def __init__(self, config: ESMFold2Config) -> None:
        super().__init__()
        dm = config.structure_head.diffusion_module
        swa_cfg = config.inputs.atom_encoder
        sh = config.structure_head
        self.diffusion_module = DiffusionModule(
            c_atom=dm.c_atom,
            c_token=dm.c_token,
            c_z=dm.c_z,
            c_s_inputs=dm.c_s_inputs,
            sigma_data=dm.sigma_data,
            fourier_dim=dm.fourier_dim,
            atom_num_blocks=dm.atom_num_blocks,
            atom_num_heads=dm.atom_num_heads,
            token_num_blocks=dm.token_num_blocks,
            token_num_heads=dm.token_num_heads,
            transition_multiplier=dm.transition_multiplier,
            swa_window_size=swa_cfg.swa_window_size,
            spatial_rope_base_frequency=swa_cfg.spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=swa_cfg.n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=swa_cfg.n_uid_rope_pairs,
            uid_rope_base_frequency=swa_cfg.uid_rope_base_frequency,
        )
        # Sampling hyperparameters
        self.sigma_data = dm.sigma_data
        self.gamma_0 = sh.gamma_0
        self.gamma_min = sh.gamma_min
        self.noise_scale = sh.noise_scale
        self.step_scale = sh.step_scale
        self.inference_s_max = sh.inference_s_max
        self.inference_s_min = sh.inference_s_min
        self.inference_p = sh.inference_p
        self.inference_num_steps = sh.inference_num_steps

    def set_kernel_backend(self, backend: str | None) -> None:
        """Forward the kernel backend choice to the wrapped diffusion module."""
        self.diffusion_module.set_kernel_backend(backend)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def inference_noise_schedule(
        self, num_steps: int | None = None, device: torch.device | None = None
    ) -> Tensor:
        """Karras power-law noise schedule.

        Returns ``steps + 1`` float32 values. For ``k = 0 .. steps-1`` the
        value is
        ``sigma_data * (s_max^(1/p) + k/(steps-1) * (s_min^(1/p) - s_max^(1/p)))^p``,
        and a final 0 is appended. ``steps == 1`` yields
        ``[s_max * sigma_data, 0]``.
        """
        steps = self.inference_num_steps if num_steps is None else int(num_steps)
        if steps == 1:
            return torch.tensor(
                [self.inference_s_max * self.sigma_data, 0.0],
                device=device,
                dtype=torch.float32,
            )
        p = float(self.inference_p)
        inv_p = 1.0 / p
        k = torch.arange(steps, device=device, dtype=torch.float32)
        base = self.inference_s_max**inv_p + (k / (steps - 1)) * (
            self.inference_s_min**inv_p - self.inference_s_max**inv_p
        )
        schedule = self.sigma_data * base.pow(p)
        return F.pad(schedule, (0, 1), value=0.0)

    def _sampling_schedule(
        self,
        num_steps: int,
        max_inference_sigma: float | None,
        device: torch.device | None,
    ) -> Tensor:
        """Noise schedule used by ``sample``, with ``max_inference_sigma`` applied.

        If the schedule starts above the cap:
          * levels above the cap are dropped;
          * the cap itself becomes the first level.

        A cap at or above the schedule maximum is an upper bound that is
        already satisfied, so the schedule is returned unchanged. Prepending
        it anyway would add a step that starts above (or repeats) the
        schedule's top level.
        """
        schedule = self.inference_noise_schedule(num_steps, device)
        if max_inference_sigma is None:
            return schedule
        cap = float(max_inference_sigma)
        if float(schedule[0]) <= cap:
            return schedule
        schedule = schedule[schedule <= cap]
        return F.pad(schedule, (1, 0), value=cap)

    @staticmethod
    def _random_rotations(n: int, dtype: torch.dtype, device: torch.device) -> Tensor:
        """Draw ``n`` random rotation matrices, shape ``(n, 3, 3)``.

        Quaternions are sampled from an isotropic Gaussian and normalised.
        Dividing by the norm signed like the real part makes that part
        non-negative. The standard quaternion-to-matrix formula is then
        applied.
        """
        q = torch.randn((n, 4), dtype=dtype, device=device)
        scale = torch.sqrt((q * q).sum(dim=1))
        signs = torch.where(q[:, 0] < 0, -scale, scale)
        q = q / signs[:, None]
        r, i, j, k = torch.unbind(q, dim=-1)
        two_s = 2.0 / (q * q).sum(dim=-1)
        return torch.stack(
            (
                1 - two_s * (j * j + k * k),
                two_s * (i * j - k * r),
                two_s * (i * k + j * r),
                two_s * (i * j + k * r),
                1 - two_s * (i * i + k * k),
                two_s * (j * k - i * r),
                two_s * (i * k - j * r),
                two_s * (j * k + i * r),
                1 - two_s * (i * i + j * j),
            ),
            dim=-1,
        ).reshape(n, 3, 3)

    def _center_random_augmentation(
        self, x: Tensor, atom_mask: Tensor, second_coords: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """Algorithm 19: center + random rotation + translation.

        ``x`` is centred on its masked mean, then rotated by a random rotation
        and shifted by a standard-normal translation, one of each per batch
        element. ``second_coords``, if given, receives exactly the same
        transform so both stay in a common frame.
        """
        bsz = x.shape[0]
        mask = atom_mask.unsqueeze(-1)  # [B, A, 1]
        denom = mask.sum(dim=1, keepdim=True).clamp(min=1)
        mean = (x * mask).sum(dim=1, keepdim=True) / denom
        x = x - mean
        if second_coords is not None:
            second_coords = second_coords - mean
        r = self._random_rotations(bsz, x.dtype, x.device)
        x = torch.einsum("bmd,bds->bms", x, r)
        if second_coords is not None:
            second_coords = torch.einsum("bmd,bds->bms", second_coords, r)
        t = torch.randn_like(x[:, 0:1, :])
        x = x + t
        if second_coords is not None:
            second_coords = second_coords + t
        return x, second_coords

    @staticmethod
    def _weighted_rigid_align(
        x: Tensor, x_gt: Tensor, w: Tensor, mask: Tensor
    ) -> Tensor:
        """Kabsch alignment: rigidly superimpose ``x`` onto ``x_gt``.

        Per-atom weights are ``mask * w``. The SVD runs in float32, and the
        determinant sign correction excludes reflections. Returns ``x``
        rotated and translated onto ``x_gt``, in ``x``'s dtype.
        """
        w = (mask * w).unsqueeze(-1)  # [B, N, 1]
        denom = w.sum(dim=-2, keepdim=True).clamp(min=1e-8)
        mu = (x * w).sum(dim=-2, keepdim=True) / denom
        mu_gt = (x_gt * w).sum(dim=-2, keepdim=True) / denom
        x_c = x - mu
        xgt_c = x_gt - mu_gt
        H = torch.einsum("bni,bnj->bij", w * xgt_c, x_c)
        H32 = H.float()
        # The "gesvd" driver option is only accepted on CUDA.
        U, _, Vh = torch.linalg.svd(H32, driver="gesvd" if H32.is_cuda else None)
        det = torch.linalg.det(U @ Vh)
        ones = torch.ones_like(det)
        R = (U @ torch.diag_embed(torch.stack([ones, ones, det], dim=-1)) @ Vh).to(
            H.dtype
        )
        return x_c @ R.transpose(-1, -2) + mu_gt

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    def sample(
        self,
        z_trunk: Tensor,
        s_inputs: Tensor,
        s_trunk: Tensor | None,
        relative_position_encoding: Tensor,
        ref_pos: Tensor,
        ref_charge: Tensor,
        ref_mask: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        ref_space_uid: Tensor,
        tok_idx: Tensor,
        asym_id: Tensor,
        residue_index: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
        sym_id: Tensor,
        token_attention_mask: Tensor | None = None,
        num_diffusion_samples: int = 1,
        num_sampling_steps: int | None = None,
        max_inference_sigma: float | None = 256.0,
        noise_scale: float | None = None,
        step_scale: float | None = None,
        return_atom_repr: bool = False,
        use_inference_cache: bool = True,
        denoising_early_exit_rmsd: float | None = None,
    ) -> dict[str, Tensor | None]:
        """Diffusion sampling (Algorithm 18).

        The Karras schedule is built with ``num_sampling_steps`` levels
        (default ``self.inference_num_steps``) plus a trailing zero. When
        ``max_inference_sigma`` is below the schedule's top level:
          * levels above it are dropped;
          * the cap becomes the first level.

        The number of denoising steps actually run is therefore at most
        ``num_sampling_steps``, and smaller when more than one level exceeds
        the cap. The schedule length is not inflated to compensate.

        Each step does the following:
          * re-centre and randomly rotate/translate the current coordinates;
          * inject noise up to ``t_hat = sigma_prev * (1 + gamma)``;
          * denoise;
          * Kabsch-align the noisy coordinates onto the prediction;
          * take an Euler step scaled by ``step_scale``.

        With ``denoising_early_exit_rmsd`` set, sampling stops once every
        sample's aligned RMSD between consecutive predictions falls below the
        threshold.

        Returns:
            Dict with:
              * ``sample_atom_coords``;
              * ``diff_token_repr`` — token representation from the last step
                run;
              * ``diff_atom_intermediates`` — present only when
                ``return_atom_repr`` is set.
        """
        n_atoms = tok_idx.shape[1]
        device = s_inputs.device
        target_batch = s_inputs.shape[0] * num_diffusion_samples
        inference_cache: dict[str, Tensor] | None = {} if use_inference_cache else None
        steps = (
            self.inference_num_steps
            if num_sampling_steps is None
            else int(num_sampling_steps)
        )
        schedule = self._sampling_schedule(steps, max_inference_sigma, device)
        lam = self.noise_scale if noise_scale is None else float(noise_scale)
        eta = self.step_scale if step_scale is None else float(step_scale)
        x = schedule[0] * torch.randn(
            target_batch, n_atoms, 3, device=device, dtype=torch.float32
        )
        atom_mask = ref_mask.repeat_interleave(num_diffusion_samples, 0).float()
        # Stochastic churn only while the target level is above gamma_min.
        gammas = torch.where(
            schedule > self.gamma_min,
            torch.full_like(schedule, self.gamma_0),
            torch.zeros_like(schedule),
        )
        x_denoised_prev: Tensor | None = None
        token_repr: Tensor | None = None
        diff_atom_intermediates: Tensor | None = None
        step_pairs = list(zip(schedule[:-1], schedule[1:], gammas[1:]))
        num_steps = len(step_pairs)
        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(step_pairs):
            x, x_denoised_prev = self._center_random_augmentation(
                x, atom_mask, second_coords=x_denoised_prev
            )
            sigma_tm_val = float(sigma_tm.item())
            t_hat_val = sigma_tm_val * (1.0 + float(gamma.item()))
            eps_std = lam * max(t_hat_val**2 - sigma_tm_val**2, 0.0) ** 0.5
            x_noisy = x + eps_std * torch.randn_like(x)
            is_last_step = step_idx == num_steps - 1
            request_atom_repr = return_atom_repr and (
                is_last_step or denoising_early_exit_rmsd is not None
            )
            dm_out = self.diffusion_module(
                x_noisy=x_noisy,
                t_hat=torch.full(
                    (target_batch,), t_hat_val, device=device, dtype=torch.float32
                ),
                ref_pos=ref_pos,
                ref_charge=ref_charge,
                ref_mask=ref_mask,
                ref_element=ref_element,
                ref_atom_name_chars=ref_atom_name_chars,
                ref_space_uid=ref_space_uid,
                tok_idx=tok_idx,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                relative_position_encoding=relative_position_encoding,
                asym_id=asym_id,
                residue_index=residue_index,
                entity_id=entity_id,
                token_index=token_index,
                sym_id=sym_id,
                token_attention_mask=token_attention_mask,
                num_diffusion_samples=num_diffusion_samples,
                return_token_repr=True,
                return_atom_repr=request_atom_repr,
                inference_cache=inference_cache,
            )
            x_denoised = dm_out["x_denoised"]
            token_repr = dm_out["token_repr"]
            if request_atom_repr:
                diff_atom_intermediates = dm_out.get("atom_intermediates")
            # Reverse diffusion alignment (Kabsch), in full precision
            with torch.autocast(device_type="cuda", enabled=False):
                x_noisy = self._weighted_rigid_align(
                    x_noisy.float(), x_denoised.float(), atom_mask, atom_mask
                )
            x_noisy = x_noisy.to(dtype=x_denoised.dtype)
            # ODE/SDE step
            sigma_t_val = float(sigma_t.item())
            denoised_over_sigma = (x_noisy - x_denoised) / t_hat_val
            x = x_noisy + eta * (sigma_t_val - t_hat_val) * denoised_over_sigma
            # Denoising early-exit: stop when consecutive predictions converge
            if (
                denoising_early_exit_rmsd is not None
                and x_denoised_prev is not None
                and step_idx >= 1
            ):
                with torch.autocast(device_type="cuda", enabled=False):
                    aligned = self._weighted_rigid_align(
                        x_denoised_prev.float(),
                        x_denoised.float(),
                        atom_mask,
                        atom_mask,
                    )
                diff = (x_denoised.float() - aligned) * atom_mask.unsqueeze(-1)
                per_sample_rmsd = (
                    diff.pow(2).sum(dim=(-1, -2)) / atom_mask.sum(dim=-1).clamp(min=1)
                ).sqrt()
                if per_sample_rmsd.max().item() < denoising_early_exit_rmsd:
                    x = x_denoised
                    x_denoised_prev = x_denoised
                    break
            x_denoised_prev = x_denoised
        result: dict[str, Tensor | None] = {
            "sample_atom_coords": x,
            "diff_token_repr": token_repr,
        }
        if return_atom_repr:
            result["diff_atom_intermediates"] = diff_atom_intermediates
        return result
| # =========================================================================== | |
| # RowAttentionPooling | |
| # =========================================================================== | |
class RowAttentionPooling(nn.Module):
    """Collapse each row of a pair tensor into one vector via learned attention.

    For every row ``n`` a scalar logit is produced for each column ``m``
    (``attn_proj``); columns flagged invalid by ``mask`` receive a -1e9 bias,
    the logits are softmax-normalised over ``m``, the row's pair features are
    averaged with those weights, and ``out_proj`` maps the result to the
    single-representation width.
    """

    def __init__(self, d_pair: int, d_single: int) -> None:
        super().__init__()
        # One attention logit per pair entry.
        self.attn_proj = nn.Linear(d_pair, 1, bias=False)
        # Maps the pooled d_pair vector to d_single.
        self.out_proj = nn.Linear(d_pair, d_single, bias=False)

    def forward(self, z: Tensor, mask: Tensor) -> Tensor:
        """Pool ``z`` over its third axis.

        Args:
            z: pair features, indexed ``[b, n, m, d]`` by the einsum below.
            mask: column-validity mask broadcast over rows via
                ``unsqueeze(1)``; presumably ``[B, L]`` (TODO confirm).

        Returns:
            ``[b, n, d_single]`` pooled and projected features.
        """
        logits = self.attn_proj(z)[..., 0]
        column_valid = mask.unsqueeze(1).bool()
        # Additive bias (not a fill) so valid logits pass through unchanged.
        penalty = torch.where(
            column_valid,
            torch.zeros_like(logits),
            torch.full_like(logits, -1e9),
        )
        attn = F.softmax(logits + penalty, dim=-1)
        row_summary = torch.einsum("bnm,bnmd->bnd", attn, z)
        return self.out_proj(row_summary)
| # =========================================================================== | |
| # InputsEmbedder | |
| # =========================================================================== | |
class InputsEmbedder(nn.Module):
    """Assemble per-token input features, including an atom-level SWA encoding.

    The atom encoder is an ``ESMFold2AtomEncoder`` configured from
    ``config.inputs.atom_encoder`` with ``structure_prediction=False``; its
    per-token output is concatenated with ``aatype``, ``profile`` and
    ``deletion_mean``.
    """

    def __init__(self, config: ESMFold2Config) -> None:
        super().__init__()
        enc = config.inputs.atom_encoder
        self.atom_attention_encoder = ESMFold2AtomEncoder(
            d_atom=enc.d_atom,
            d_token=enc.d_token,
            n_blocks=enc.n_blocks,
            n_heads=enc.n_heads,
            swa_window_size=enc.swa_window_size,
            expansion_ratio=enc.expansion_ratio,
            structure_prediction=False,  # no coords_linear
            spatial_rope_base_frequency=enc.spatial_rope_base_frequency,
            n_spatial_rope_pairs_per_axis=enc.n_spatial_rope_pairs_per_axis,
            n_uid_rope_pairs=enc.n_uid_rope_pairs,
            uid_rope_base_frequency=enc.uid_rope_base_frequency,
        )

    def forward(
        self,
        aatype: Tensor,
        profile: Tensor,
        deletion_mean: Tensor,
        ref_pos: Tensor,
        atom_attention_mask: Tensor,
        ref_space_uid: Tensor,
        ref_charge: Tensor,
        ref_element: Tensor,
        ref_atom_name_chars: Tensor,
        atom_to_token: Tensor,
    ) -> Tensor:
        """Embed inputs into per-token features.

        Returns:
            [B, L, d_inputs] concatenation (last axis) of the atom encoding,
            ``aatype``, ``profile`` and ``deletion_mean`` (as one channel).
        """
        encoder_out = self.atom_attention_encoder(
            ref_pos=ref_pos,
            atom_attention_mask=atom_attention_mask,
            ref_space_uid=ref_space_uid,
            ref_charge=ref_charge,
            ref_element=ref_element,
            ref_atom_name_chars=ref_atom_name_chars,
            atom_to_token=atom_to_token,
        )
        # The encoder yields a 5-tuple; only the per-token activations are used.
        token_feats, _, _, _, _ = encoder_out
        pieces = (token_feats, aatype, profile, deletion_mean[..., None])
        return torch.cat(pieces, dim=-1)
| # =========================================================================== | |
| # ResIdxAsymIdSymIdEntityIdEncoding (trunk relative position) | |
| # =========================================================================== | |
class ResIdxAsymIdSymIdEntityIdEncoding(nn.Module):
    """Relative-position encoding of token pairs for the trunk pair representation.

    ``embed.weight`` is ``[d_pair, n_features]`` with
    ``n_features = 2*(2*r_bins+2) + 1 + (2*c_bins+2)``; for the defaults
    ``r_bins=32, c_bins=2`` that is ``2*66 + 1 + 6 = 139``.

    Features are concatenated in this order (must match the checkpoint):
      * relative residue index one-hot (``2*r_bins + 2``),
      * relative token index one-hot (``2*r_bins + 2``),
      * same-entity flag (1),
      * relative ``sym_id`` one-hot (``2*c_bins + 2``).
    """

    def __init__(
        self,
        n_relative_residx_bins: int = 32,
        n_relative_chain_bins: int = 2,
        d_pair: int = 256,
    ) -> None:
        super().__init__()
        self.n_relative_residx_bins = n_relative_residx_bins
        self.n_relative_chain_bins = n_relative_chain_bins
        self.d_pair = d_pair
        n_feats_residue = 2 * n_relative_residx_bins + 2
        n_feats_token = 2 * n_relative_residx_bins + 2
        n_feats_chain = 2 * n_relative_chain_bins + 2
        n_feats_same_entity = 1
        total_feats = (
            n_feats_residue + n_feats_token + n_feats_chain + n_feats_same_entity
        )
        self.embed = nn.Linear(total_feats, d_pair, bias=False)

    def forward(
        self,
        residue_index: Tensor,
        asym_id: Tensor,
        sym_id: Tensor,
        entity_id: Tensor,
        token_index: Tensor,
    ) -> Tensor:
        """Embed pairwise relative-position features.

        Args:
            residue_index, asym_id, sym_id, entity_id, token_index: ``[B, L]``
                integer tensors (any integer dtype).

        Returns:
            ``[B, L, L, d_pair]`` pair embedding in ``embed``'s dtype.
        """
        r = self.n_relative_residx_bins
        c = self.n_relative_chain_bins

        same_chain = asym_id.unsqueeze(2) == asym_id.unsqueeze(1)
        same_residue = residue_index.unsqueeze(2) == residue_index.unsqueeze(1)
        same_entity = entity_id.unsqueeze(2) == entity_id.unsqueeze(1)

        # Relative residue index: clipped offset inside a chain, sentinel
        # bin 2r+1 across chains.
        d_res = torch.clip(
            residue_index.unsqueeze(2) - residue_index.unsqueeze(1) + r, 0, 2 * r
        )
        d_res = torch.where(same_chain, d_res, 2 * r + 1)

        # Relative token index: only meaningful within the same residue of
        # the same chain; sentinel bin 2r+1 otherwise.
        d_tok = torch.clip(
            token_index.unsqueeze(2) - token_index.unsqueeze(1) + r, 0, 2 * r
        )
        d_tok = torch.where(same_chain & same_residue, d_tok, 2 * r + 1)

        # Relative sym_id: clipped offset across chains, sentinel bin 2c+1
        # within the same chain.
        d_chain = torch.clip(sym_id.unsqueeze(2) - sym_id.unsqueeze(1) + c, 0, 2 * c)
        d_chain = torch.where(same_chain, 2 * c + 1, d_chain)

        # F.one_hot only accepts int64 indices; cast so int32 inputs work too.
        rel_pos = F.one_hot(d_res.long(), 2 * r + 2)
        rel_token = F.one_hot(d_tok.long(), 2 * r + 2)
        rel_chain = F.one_hot(d_chain.long(), 2 * c + 2)

        # 0/1 features are exact in every float dtype, so building them in the
        # embedding's dtype is lossless and avoids a float32-input vs.
        # low-precision-weight mismatch when running without autocast.
        dtype = self.embed.weight.dtype
        feats = torch.cat(
            [
                rel_pos.to(dtype),
                rel_token.to(dtype),
                same_entity.to(dtype).unsqueeze(-1),
                rel_chain.to(dtype),
            ],
            dim=-1,
        )
        return self.embed(feats)
| # =========================================================================== | |
| # SingleToPair (for LanguageModelShim) | |
| # =========================================================================== | |
class SingleToPair(nn.Module):
    """Lift per-token features to pairwise features.

    Tokens are down-projected; for every pair ``(i, j)`` the elementwise
    product and difference of the two projected vectors are concatenated and
    passed through ``output_mlp`` (Linear -> GELU -> Linear).
    """

    def __init__(self, input_dim: int, downproject_dim: int, output_dim: int) -> None:
        super().__init__()
        self.downproject = nn.Linear(input_dim, downproject_dim)
        mlp_layers = [
            nn.Linear(2 * downproject_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.output_mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``[B, L, input_dim]`` to ``[B, L, L, output_dim]``."""
        proj = self.downproject(x)
        rows = proj.unsqueeze(2)  # varies along i
        cols = proj.unsqueeze(1)  # varies along j
        pair_feats = torch.cat((rows * cols, rows - cols), dim=3)
        return self.output_mlp(pair_feats)
| # =========================================================================== | |
| # LanguageModelShim | |
| # =========================================================================== | |
class LanguageModelShim(nn.Module):
    """Shim holding the trainable projection weights for LM integration.

    Contains:
      - base_z_combine: nn.Parameter [num_layers+1], softmax-normalised layer
        mixing weights (zero-initialised, i.e. uniform mixing at init)
      - base_z_linear: Sequential(LayerNorm(d_model), Linear(d_model, d_z, bias=False))
      - base_z_mlp: Sequential(SingleToPair(d_z, d_z, d_z), LayerNorm(d_z))
    """

    def __init__(
        self, d_z: int = 256, d_model: int = 2560, num_layers: int = 80
    ) -> None:
        super().__init__()
        self.base_z_mlp = nn.Sequential(SingleToPair(d_z, d_z, d_z), nn.LayerNorm(d_z))
        self.base_z_linear = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_z, bias=False)
        )
        self.base_z_combine = nn.Parameter(torch.zeros(num_layers + 1))

    def forward(self, hidden_states: Tensor, *, lm_dropout: float = 0.0) -> Tensor:
        """Project pre-computed ESMC hidden states to a pair representation.

        Args:
            hidden_states: [B, L, num_layers+1, d_model] from ESMC.
            lm_dropout: Dropout probability applied to the pair representation
                after ``base_z_mlp``. Applied whenever > 0, independent of the
                module's train/eval mode.

        Returns:
            [B, L, L, d_z] pair representation.
        """
        lm_z = self.base_z_linear(hidden_states)  # [B, L, num_layers+1, d_z]
        weights = self.base_z_combine.softmax(0)  # [num_layers+1]
        # With a 1-D left operand, matmul prepends a unit dim and removes it
        # afterwards, so this is already [B, L, d_z]. No squeeze here: a
        # squeeze(-2) would drop the L axis whenever L == 1.
        lm_z = weights @ lm_z
        lm_z = self.base_z_mlp(lm_z)  # [B, L, L, d_z]
        if lm_dropout > 0:
            lm_z = F.dropout(lm_z, p=lm_dropout, training=True)
        return lm_z
| # =========================================================================== | |
| # Reproducibility helper (mirrors evolutionaryscale.utils.reproducibility) | |
| # =========================================================================== | |
| def _seed_context(seed: int | None, *, cuda: bool = True): | |
| """Temporarily seed Python, NumPy, and PyTorch RNGs.""" | |
| if seed is None: | |
| yield | |
| return | |
| py_state = random.getstate() | |
| np_state = np.random.get_state() | |
| torch_state = torch.get_rng_state() | |
| cuda_states = ( | |
| torch.cuda.get_rng_state_all() if cuda and torch.cuda.is_available() else None | |
| ) | |
| seed = int(seed) % (2**32) | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| if cuda and torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| try: | |
| yield | |
| finally: | |
| random.setstate(py_state) | |
| np.random.set_state(np_state) | |
| torch.set_rng_state(torch_state) | |
| if cuda_states is not None: | |
| torch.cuda.set_rng_state_all(cuda_states) | |
| # =========================================================================== | |
# compute_lm_hidden_states — ESMC language-model hidden-state extraction
| # =========================================================================== | |
| def compute_lm_hidden_states( | |
| esmc: nn.Module, | |
| input_ids: Tensor, | |
| asym_id: Tensor, | |
| residue_index: Tensor, | |
| mol_type: Tensor, | |
| token_mask: Tensor, | |
| pad_to_multiple: int | None = None, | |
| lm_mask_pct: float = 0.0, | |
| mask_token_id: int = 32, | |
| ) -> Tensor: | |
| """Run ESMC with BOS/EOS wrapping, return hidden states [B, L, N, D] with N=81 layers. | |
| Atom-tokenized modified residues (HYP, MSE, ACE, NH2, ...) span multiple | |
| structure tokens but share a single ``(asym_id, residue_index)`` key — | |
| collapse them to one LM token per residue before running the LM (the LM | |
| was trained on per-residue inputs, not per-atom), then scatter the | |
| hidden states back to the per-token layout. | |
| """ | |
| B, L = input_ids.shape | |
| device = input_ids.device | |
| protein_mask = (mol_type == 0) & token_mask | |
| lm_input_list = [] | |
| lm_lengths = [] | |
| # Per-batch maps from (original protein-token index) to (LM input position). | |
| expand_maps: list[Tensor] = [] | |
| for b in range(B): | |
| mask_b = protein_mask[b] | |
| ids_b = input_ids[b][mask_b] | |
| asym_b = asym_id[b][mask_b] | |
| res_b = residue_index[b][mask_b] | |
| # Collapse: keep first token per (asym_id, residue_index) key, in | |
| # input order. ``inverse`` maps each original protein-token to its | |
| # collapsed residue index. | |
| keys = torch.stack((asym_b, res_b), dim=1) | |
| unique_keys, inverse = torch.unique(keys, dim=0, return_inverse=True) | |
| n_unique = unique_keys.size(0) | |
| token_positions = torch.arange(keys.size(0), device=device, dtype=torch.long) | |
| first_pos = torch.full( | |
| (n_unique,), keys.size(0), device=device, dtype=torch.long | |
| ) | |
| first_pos.scatter_reduce_( | |
| 0, inverse, token_positions, reduce="amin", include_self=True | |
| ) | |
| ordered = torch.argsort(first_pos) | |
| first_pos_ordered = first_pos[ordered] | |
| ids_collapsed = ids_b[first_pos_ordered] | |
| asym_collapsed = asym_b[first_pos_ordered] | |
| remap = torch.empty_like(ordered) | |
| remap[ordered] = torch.arange(n_unique, device=device, dtype=torch.long) | |
| inverse_ordered = remap[inverse] | |
| chain_ids = asym_collapsed.unique(sorted=True) | |
| # [BOS] chain1 [EOS BOS] chain2 ... [EOS] | |
| parts: list[Tensor] = [torch.tensor([0], device=device, dtype=ids_b.dtype)] | |
| # Per-chain LM positions accumulate; track them for the expand map. | |
| per_token_lm_pos = torch.empty(n_unique, device=device, dtype=torch.long) | |
| cursor = 1 # position 0 is the leading BOS | |
| for i, cid in enumerate(chain_ids): | |
| in_chain = (asym_collapsed == cid).nonzero(as_tuple=True)[0] | |
| parts.append(ids_collapsed[in_chain]) | |
| per_token_lm_pos[in_chain] = torch.arange( | |
| cursor, cursor + in_chain.shape[0], device=device, dtype=torch.long | |
| ) | |
| cursor += in_chain.shape[0] | |
| if i < len(chain_ids) - 1: | |
| parts.append(torch.tensor([2, 0], device=device, dtype=ids_b.dtype)) | |
| cursor += 2 # EOS + BOS | |
| parts.append(torch.tensor([2], device=device, dtype=ids_b.dtype)) | |
| lm_seq = torch.cat(parts) | |
| lm_input_list.append(lm_seq) | |
| lm_lengths.append(lm_seq.shape[0]) | |
| # Original protein-token position → LM input position. | |
| prot_pos_b = mask_b.nonzero(as_tuple=True)[0] | |
| expand_map = torch.full((L,), -1, device=device, dtype=torch.long) | |
| expand_map[prot_pos_b] = per_token_lm_pos[inverse_ordered] | |
| expand_maps.append(expand_map) | |
| # Pad to longest LM input; round to ``pad_to_multiple`` when fp8 is on | |
| # (TE fp8 kernels assert prod(shape[:-1]) % 8 == 0). | |
| max_len = max(lm_lengths) | |
| if pad_to_multiple is not None and pad_to_multiple > 1: | |
| max_len = ((max_len + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple | |
| lm_input_ids = torch.full( | |
| (B, max_len), | |
| 1, | |
| device=device, | |
| dtype=input_ids.dtype, # PAD=1 | |
| ) | |
| for b in range(B): | |
| lm_input_ids[b, : lm_lengths[b]] = lm_input_list[b] | |
| # sequence_id for chain-aware attention; PAD tokens get -1 (no attention). | |
| sequence_id = (lm_input_ids == 0).cumsum(dim=1) - 1 # BOS=0 | |
| sequence_id = sequence_id.masked_fill(lm_input_ids == 1, -1) # PAD=1 | |
| if lm_mask_pct > 0.0: | |
| special = (lm_input_ids == 0) | (lm_input_ids == 1) | (lm_input_ids == 2) | |
| do_mask = ( | |
| torch.rand(lm_input_ids.shape, device=device) < lm_mask_pct | |
| ) & ~special | |
| lm_input_ids = lm_input_ids.masked_fill(do_mask, mask_token_id) | |
| with torch.inference_mode(): | |
| esmc_out = esmc( | |
| input_ids=lm_input_ids, sequence_id=sequence_id, output_hidden_states=True | |
| ) | |
| hs = esmc_out.hidden_states # [n_layers+1, B, max_len, D] | |
| n_layers_plus_1, _, _, D = hs.shape | |
| result = torch.zeros(B, L, n_layers_plus_1, D, device=device, dtype=hs.dtype) | |
| for b in range(B): | |
| mb = protein_mask[b] | |
| em = expand_maps[b][mb] # [n_protein_tokens] LM positions | |
| # hs[:, b, em, :] -> [n_layers+1, n_protein_tokens, D] | |
| gathered = hs[:, b, em, :].permute(1, 0, 2) | |
| result[b, mb.nonzero(as_tuple=True)[0]] = gathered | |
| return result.detach() | |
| # =========================================================================== | |
| # TriangleMultiplicativeUpdate | |
| # =========================================================================== | |
class TriangleMultiplicativeBlock(nn.Module):
    """Triangle multiplicative update block with gated signal routing.

    Reference pipeline:
      1. ``norm_start`` LayerNorm.
      2. A single ``proj_bundle`` projection split into a signal and gate
         logits (each ``2 * latent_channels`` wide).
      3. Sigmoid gating, then multiplication by ``visibility``.
      4. Split into left/right streams, cast to float32.
      5. Triangular einsum (``outgoing``: ``bikd,bjkd->bijd``; ``incoming``:
         ``bkid,bkjd->bijd``), optionally chunked over the output row axis.
      6. ``norm_mix`` and ``proj_emit``, multiplied by a sigmoid output gate
         (``proj_gate``) computed from the normalised input.

    When ``_use_kernels`` is set, the cuequivariance kernel ``_cue_tri_mul``
    is tried first. Any exception it raises is logged and the reference
    pipeline is used instead.
    """

    _FLOW_TO_EINSUM = {"outgoing": "bikd,bjkd->bijd", "incoming": "bkid,bkjd->bijd"}
    _VALID_FLOWS = ("outgoing", "incoming")

    def __init__(self, input_channels: int, latent_channels: int, flow: str) -> None:
        super().__init__()
        if flow not in self._FLOW_TO_EINSUM:
            raise ValueError(
                f"Invalid flow={flow!r}. Expected one of {self._VALID_FLOWS}."
            )
        self.input_channels = input_channels
        self.latent_channels = latent_channels
        self.flow = flow
        self._einsum_equation = self._FLOW_TO_EINSUM[flow]
        d_in, d_lat = input_channels, latent_channels
        self.norm_start = nn.LayerNorm(d_in, eps=_EPS)
        self.norm_mix = nn.LayerNorm(d_lat, eps=_EPS)
        # Rows [0, 2*d_lat) are the signal projection; rows [2*d_lat, 4*d_lat)
        # are the gate projection (see split_kernel_weights).
        self.proj_bundle = nn.Linear(d_in, 4 * d_lat, bias=False)
        self.proj_emit = nn.Linear(d_lat, d_in, bias=False)
        self.proj_gate = nn.Linear(d_in, d_in, bias=False)
        self._use_kernels: bool = False
        # Chunked by default to bound memory on long sequences; tests call
        # ``set_chunk_size(None)`` for the unchunked path under bit-exact bf16
        # parity checks.
        self._chunk_size: int | None = 64

    def set_chunk_size(self, chunk_size: int | None) -> None:
        self._chunk_size = chunk_size

    def split_kernel_weights(self) -> tuple[Tensor, Tensor]:
        """Return ``(signal rows, gate rows)`` of ``proj_bundle.weight``."""
        cut = 2 * self.latent_channels
        bundle_w = self.proj_bundle.weight
        return bundle_w[:cut, :], bundle_w[cut:, :]

    def _kernel_flow_direction(self) -> str:
        return self.flow

    def _triangular_contract(self, left_stream: Tensor, right_stream: Tensor) -> Tensor:
        return torch.einsum(self._einsum_equation, left_stream, right_stream)

    def _triangular_contract_chunked(
        self, left_stream: Tensor, right_stream: Tensor, chunk_size: int
    ) -> Tensor:
        """Compute the triangular einsum in chunks along the output i-dimension."""
        # The output row index i sits on dim 1 of the left stream for
        # "outgoing" and on dim 2 for "incoming".
        i_dim = 1 if self.flow == "outgoing" else 2
        n_rows = left_stream.shape[i_dim]
        pieces = [
            torch.einsum(
                self._einsum_equation,
                left_stream.narrow(i_dim, lo, min(chunk_size, n_rows - lo)),
                right_stream,
            )
            for lo in range(0, n_rows, chunk_size)
        ]
        return torch.cat(pieces, dim=1)

    def _warn_kernel_fallback(self, pair_grid: Tensor, err: Exception) -> None:
        """Log that the cuequivariance kernel failed and the fallback is used."""
        import logging as _logging

        _logging.getLogger(__name__).warning(
            "cuequivariance triangle_multiplicative_update kernel failed "
            "(flow=%s, shape=%s, dtype=%s); falling back to chunked einsum. "
            "Error: %s",
            self.flow,
            tuple(pair_grid.shape),
            pair_grid.dtype,
            err,
        )

    def _reference_forward(self, pair_grid: Tensor, visibility: Tensor) -> Tensor:
        """Pure-PyTorch triangle update (see class docstring)."""
        normalized = self.norm_start(pair_grid)
        signal, gate_logits = self.proj_bundle(normalized).split(
            2 * self.latent_channels, dim=-1
        )
        gated = signal * torch.sigmoid(gate_logits) * visibility.unsqueeze(-1)
        left, right = gated.float().chunk(2, dim=-1)
        if self._chunk_size is None:
            contracted = self._triangular_contract(left, right)
        else:
            contracted = self._triangular_contract_chunked(
                left, right, self._chunk_size
            )
        emitted = self.proj_emit(self.norm_mix(contracted))
        out_gate = torch.sigmoid(self.proj_gate(normalized))
        return emitted * out_gate

    def forward(self, pair_grid: Tensor, visibility: Tensor | None = None) -> Tensor:
        if visibility is None:
            visibility = pair_grid.new_ones(pair_grid.shape[:-1])
        if self._use_kernels:
            p_in_weight, g_in_weight = self.split_kernel_weights()
            try:
                return _cue_tri_mul(  # type: ignore[misc]
                    pair_grid,
                    direction=self._kernel_flow_direction(),
                    mask=visibility,
                    norm_in_weight=self.norm_start.weight,
                    norm_in_bias=self.norm_start.bias,
                    p_in_weight=p_in_weight,
                    g_in_weight=g_in_weight,
                    norm_out_weight=self.norm_mix.weight,
                    norm_out_bias=self.norm_mix.bias,
                    p_out_weight=self.proj_emit.weight,
                    g_out_weight=self.proj_gate.weight,
                    eps=_EPS,
                )
            except Exception as err:
                # Deliberate best-effort: keep running on the reference path.
                self._warn_kernel_fallback(pair_grid, err)
        return self._reference_forward(pair_grid, visibility)
class TriangleMultiplicativeUpdate(nn.Module):
    """Thin wrapper exposing the triangular mixer with explicit orientation (v3).

    Holds one ``TriangleMultiplicativeBlock`` registered as ``_engine``,
    configured for the ``outgoing`` flow when ``_outgoing`` is true and
    ``incoming`` otherwise.
    """

    def __init__(self, dim: int = 128, _outgoing: bool = True) -> None:
        super().__init__()
        flow = "outgoing" if _outgoing else "incoming"
        self._engine = TriangleMultiplicativeBlock(
            input_channels=dim, latent_channels=dim, flow=flow
        )

    def set_kernel_backend(self, backend: str | None) -> None:
        """Enable the cuequivariance kernel iff ``backend == BACKEND_CUEQ``.

        The "fused" backend is routed through the parent PairUpdateBlock's
        fused path (bypassing this module), so for it, and for ``None``,
        kernels are disabled here.

        Raises:
            RuntimeError: if the cuequivariance backend is requested but
                ``cuequivariance_torch`` is not installed. The current
                setting is left unchanged in that case.
        """
        use_cueq = backend == BACKEND_CUEQ
        # Validate before mutating, so a failed call does not leave the engine
        # flagged to use a kernel that is not available.
        if use_cueq and not CUE_AVAILABLE:
            raise RuntimeError(
                "backend='cuequivariance' but cuequivariance_torch is not installed."
            )
        self._engine._use_kernels = use_cueq

    def set_chunk_size(self, chunk_size: int | None) -> None:
        self._engine.set_chunk_size(chunk_size)

    def forward(self, z: Tensor, mask: Tensor | None = None) -> Tensor:
        return self._engine(z, visibility=mask)
| # =========================================================================== | |
| # FoldingTrunk: Transition, PairUpdateBlock, FoldingTrunk | |
| # =========================================================================== | |
class Transition(nn.Module):
    """Pre-LayerNorm SwiGLU feed-forward block with a residual connection.

    Reference path: ``x + ffn(norm(x))``. When the ``fused`` backend has
    installed ``_fused_swiglu`` (a Triton LN + w12 + SwiGLU kernel), inference
    with grad disabled and bf16 input instead runs that kernel and adds the w3
    projection onto the residual with a single ``addmm``. Both paths
    optionally process dim 1 in slices of ``_chunk_size``.
    """

    def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)
        # Chunked by default; set_chunk_size(None) disables it for bit-exact
        # parity tests.
        self._chunk_size: int | None = 64
        self._fused_swiglu: nn.Module | None = None
        self._kernel_backend: str | None = None

    def set_chunk_size(self, chunk_size: int | None) -> None:
        self._chunk_size = chunk_size

    def set_kernel_backend(self, backend: str | None) -> None:
        """Install / uninstall FusedLNLinearSwiGLU (no cueq equivalent)."""
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        self._kernel_backend = backend
        if backend != BACKEND_FUSED or not TRITON_KERNELS_AVAILABLE:
            self._fused_swiglu = None
            return
        self._fused_swiglu = self._build_fused_swiglu()

    def _build_fused_swiglu(self) -> nn.Module:
        """Create a frozen FusedLNLinearSwiGLU holding copies of the current weights."""
        assert _FusedLNLinearSwiGLU is not None
        w12 = self.ffn.w12.weight
        ln_bias = self.norm.bias
        kernel = _FusedLNLinearSwiGLU(
            d_model=self.norm.normalized_shape[0],
            d_inner=self.ffn.hidden_features,
            has_ln_bias=ln_bias is not None,
            device=w12.device,
            dtype=w12.dtype,
        )
        with torch.no_grad():
            kernel.LN_W.copy_(self.norm.weight)
            if ln_bias is not None:
                kernel.LN_B.copy_(ln_bias)  # type: ignore[union-attr]
            # Kernel stores W12 as (d_model, 2*d_inner); nn.Linear is transposed.
            kernel.W12.copy_(w12.t().contiguous())
        return kernel.eval().requires_grad_(False)

    def _can_use_fused_path(self, x: Tensor) -> bool:
        if not _fused_active(self, x):
            return False
        return self._fused_swiglu is not None and x.dtype == torch.bfloat16

    def _swiglu_pre_w3(self, x_normed: Tensor) -> Tensor:
        """SwiGLU through silu(x1)*x2, before the final w3."""
        x1, x2 = self.ffn.w12(x_normed).split(self.ffn.hidden_features, dim=-1)
        return F.silu(x1) * x2

    def _addmm_residual(self, x: Tensor, hidden: Tensor) -> Tensor:
        """x + w3(hidden) via single cuBLAS addmm — avoids transition-output allocation."""
        width = x.shape[-1]
        flat = torch.addmm(
            x.contiguous().view(-1, width),
            hidden.view(-1, hidden.shape[-1]),
            self.ffn.w3.weight.t(),
        )
        return flat.view(x.shape)

    def _forward_fused(self, x: Tensor) -> Tensor:
        """Inference-only fused path (kernel pre-w3, addmm residual)."""
        kernel = self._fused_swiglu
        assert kernel is not None
        step = self._chunk_size
        n = x.shape[1]
        if step is None or n <= step:
            return self._addmm_residual(x, kernel(x))
        out = torch.empty_like(x)
        for lo in range(0, n, step):
            hi = min(lo + step, n)
            piece = x[:, lo:hi]
            out[:, lo:hi] = self._addmm_residual(piece, kernel(piece))
        return out

    def _forward_reference(self, x: Tensor) -> Tensor:
        """Reference path, bit-exact with main: x + ffn(norm(x))."""
        step = self._chunk_size
        if step is None or x.shape[1] <= step:
            return x + self.ffn(self.norm(x))
        return torch.cat(
            [piece + self.ffn(self.norm(piece)) for piece in x.split(step, dim=1)],
            dim=1,
        )

    def forward(self, x: Tensor) -> Tensor:
        # The fused path is not bit-exact with ``x + ffn(norm(x))``, so it is
        # only taken with grad disabled (binder-design / bit-exact tests run
        # with grad on and need the reference path).
        if not torch.is_grad_enabled() and self._can_use_fused_path(x):
            return self._forward_fused(x)
        return self._forward_reference(x)
class PairUpdateBlock(nn.Module):
    """One trunk layer: outgoing and incoming triangle updates, then a transition.

    Submodules: ``tri_mul_out``, ``tri_mul_in``, ``pair_transition``.

    When ``_fused_active`` allows it and the pair tensor is bf16, both
    triangle updates, including their residual adds, go through the fused
    TriMul+residual kernel. Otherwise each update is added via ``row_drop``,
    a ``DropoutResidual`` with rate 0.0.
    """

    def __init__(self, d_pair: int = 256, expansion_ratio: int = 4) -> None:
        super().__init__()
        self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
        self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
        self.pair_transition = Transition(d_pair, expansion_ratio=expansion_ratio)
        self._kernel_backend: str | None = None
        # Row-shared dropout-residual with rate 0 (the HF model is
        # inference-only); backend='fused' swaps in the Triton kernel.
        self.row_drop = DropoutResidual(0.0, batch_dim=1, use_fused_kernels=False)

    def set_kernel_backend(self, backend: str | None) -> None:
        if backend not in _VALID_BACKENDS:
            raise ValueError(
                f"backend must be one of {_VALID_BACKENDS}, got {backend!r}"
            )
        for sub in (self.tri_mul_out, self.tri_mul_in, self.pair_transition):
            sub.set_kernel_backend(backend)
        self._kernel_backend = backend
        self.row_drop = DropoutResidual(
            0.0, batch_dim=1, use_fused_kernels=(backend == BACKEND_FUSED)
        )

    def set_chunk_size(self, chunk_size: int | None) -> None:
        for sub in (self.tri_mul_out, self.tri_mul_in, self.pair_transition):
            sub.set_chunk_size(chunk_size)

    def _can_use_fused_trimul_with_residual(self, pair: Tensor) -> bool:
        return _fused_active(self, pair) and pair.dtype == torch.bfloat16

    @staticmethod
    def _as_bf16(t: Tensor) -> Tensor:
        """Return ``t`` in bfloat16, avoiding a copy when it already is."""
        return t if t.dtype == torch.bfloat16 else t.to(torch.bfloat16)

    def _fused_trimul_with_residual(
        self, pair: Tensor, direction: str, pair_attention_mask: Tensor | None
    ) -> Tensor:
        """Fused TriMul+residual call; weights from the corresponding engine."""
        wrapper = self.tri_mul_out if direction == "outgoing" else self.tri_mul_in
        engine: TriangleMultiplicativeBlock = wrapper._engine  # type: ignore[assignment]
        p_in_weight, g_in_weight = engine.split_kernel_weights()
        bf16 = self._as_bf16
        # Calls the module-level fused kernel of the same name.
        return _fused_trimul_with_residual(  # type: ignore[misc]
            pair,
            direction,
            residual=pair,
            drop_mask=None,  # inference: no dropout, matches internal's eval path
            norm_in_weight=bf16(engine.norm_start.weight),
            norm_in_bias=bf16(engine.norm_start.bias),
            p_in_weight=bf16(p_in_weight),
            g_in_weight=bf16(g_in_weight),
            norm_out_weight=bf16(engine.norm_mix.weight),
            norm_out_bias=bf16(engine.norm_mix.bias),
            p_out_weight=bf16(engine.proj_emit.weight),
            g_out_weight=bf16(engine.proj_gate.weight),
            mask=pair_attention_mask,
            eps=_EPS,
        )

    def forward(
        self, pair: Tensor, pair_attention_mask: Tensor | None = None
    ) -> Tensor:
        if self._can_use_fused_trimul_with_residual(pair):
            for direction in ("outgoing", "incoming"):
                pair = self._fused_trimul_with_residual(
                    pair, direction, pair_attention_mask
                )
        else:
            for tri in (self.tri_mul_out, self.tri_mul_in):
                pair = self.row_drop(pair, tri(pair, mask=pair_attention_mask))
        return self.pair_transition(pair)
class FoldingTrunk(nn.Module):
    """Sequential stack of ``PairUpdateBlock`` layers held in ``blocks``.

    When grad is enabled each block runs under non-reentrant activation
    checkpointing. If the first block uses the fused backend, a non-bf16 CUDA
    pair tensor is cast to bf16 for the stack and cast back to its input
    dtype afterwards.
    """

    def __init__(
        self, n_layers: int = 24, d_pair: int = 256, expansion_ratio: int = 4
    ) -> None:
        super().__init__()
        layers = [
            PairUpdateBlock(d_pair=d_pair, expansion_ratio=expansion_ratio)
            for _ in range(n_layers)
        ]
        self.blocks = nn.ModuleList(layers)

    def set_kernel_backend(self, backend: str | None) -> None:
        for layer in self.blocks:
            cast(PairUpdateBlock, layer).set_kernel_backend(backend)

    def set_chunk_size(self, chunk_size: int | None) -> None:
        for layer in self.blocks:
            cast(PairUpdateBlock, layer).set_chunk_size(chunk_size)

    def forward(
        self, pair: Tensor, pair_attention_mask: Tensor | None = None
    ) -> Tensor:
        input_dtype = pair.dtype
        # The fused trimul backward kernel requires bf16; other backends keep
        # the caller's dtype.
        fused_backend = (
            len(self.blocks) != 0
            and getattr(self.blocks[0], "_kernel_backend", None) == BACKEND_FUSED
        )
        if fused_backend and pair.is_cuda and input_dtype != torch.bfloat16:
            pair = pair.to(torch.bfloat16)
        for layer in self.blocks:
            step = partial(layer, pair_attention_mask=pair_attention_mask)
            if not torch.is_grad_enabled():
                pair = step(pair)
                continue
            pair = checkpoint(step, pair, use_reentrant=False)  # pyright: ignore
        return pair if pair.dtype == input_dtype else pair.to(input_dtype)
| # =========================================================================== | |
| # MSA Encoder | |
| # =========================================================================== | |
class OuterProductMean(nn.Module):
    """Outer-product mean: maps an MSA representation into a pair update.

    For every token pair ``(i, j)`` the outer products of the projected MSA
    features are summed over MSA rows valid at both positions, projected by
    ``Wout``, and normalised by that valid-row count (clamped to >= 1).

    The order of the ``/ n_valid`` divide vs. the ``Wout`` projection is
    selectable via ``divide_outer_before_proj``, because different ESMFold2
    checkpoints were trained with different orderings:

    * ``False`` (default): ``Wout(outer) / n_valid`` — the projection bias
      is scaled by 1/n_valid along with the outer product.
    * ``True``: ``Wout(outer / n_valid)`` — the projection bias is added
      unscaled, after the divide.
    """

    def __init__(
        self,
        d_msa: int,
        d_hidden: int,
        d_pair: int,
        divide_outer_before_proj: bool = False,
    ) -> None:
        super().__init__()
        self.d_hidden = d_hidden
        self.divide_outer_before_proj = divide_outer_before_proj
        self.norm = nn.LayerNorm(d_msa)
        self.W = nn.Linear(d_msa, 2 * d_hidden, bias=False)
        self.Wout = nn.Linear(d_hidden * d_hidden, d_pair, bias=True)
        # Off for bit-exact bf16; ``set_chunk_size(64)`` for long sequences.
        self._chunk_size: int | None = None

    def set_chunk_size(self, chunk_size: int | None) -> None:
        self._chunk_size = chunk_size

    def _project(self, outer: Tensor, n_valid: Tensor) -> Tensor:
        """Apply ``Wout`` and the ``1/n_valid`` scaling in the configured order."""
        if self.divide_outer_before_proj:
            return self.Wout(outer / n_valid)
        return self.Wout(outer) / n_valid

    def forward(self, m: Tensor, msa_attention_mask: Tensor) -> Tensor:
        """
        Args:
            m: MSA features indexed ``[b, i, m, c]`` (token-major) by the einsum.
            msa_attention_mask: validity of each (token, MSA row), ``[B, L, M]``.

        Returns:
            ``[B, L, L, d_pair]`` pair update.
        """
        normed = self.norm(m)
        projected = self.W(normed) * msa_attention_mask.unsqueeze(-1).to(normed.dtype)
        left, right = projected.chunk(2, dim=-1)
        valid = msa_attention_mask.to(left.dtype)
        # Number of MSA rows valid at both i and j.
        n_valid = torch.matmul(valid, valid.transpose(-1, -2)).unsqueeze(-1).clamp(min=1.0)
        equation = "bimc,bjmd->bijcd"
        step = self._chunk_size
        if step is None:
            return self._project(torch.einsum(equation, left, right).flatten(-2), n_valid)
        # Chunk the left (i) axis so the peak intermediate is
        # [B, chunk, L, c, d] instead of [B, L, L, c, d].
        n_rows = left.shape[1]
        pieces: list[Tensor] = []
        for lo in range(0, n_rows, step):
            hi = min(lo + step, n_rows)
            outer = torch.einsum(equation, left[:, lo:hi], right).flatten(-2)
            pieces.append(self._project(outer, n_valid[:, lo:hi]))
        return torch.cat(pieces, dim=1)
class MSAPairWeightedAveraging(nn.Module):
    """Pair-biased MSA row update (AF3 Supplement Algorithm 10).

    Per-head attention weights over tokens ``j`` come solely from the pair
    representation (masked pairs get a -1e5 logit). Each token's output is
    the weighted average of the value vectors of the other tokens' MSA
    entries, gated elementwise by a sigmoid of its own MSA features.
    """

    def __init__(
        self, d_msa: int, d_pair: int, n_heads: int = 8, head_width: int = 32
    ) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_width = head_width
        self.norm_single = nn.LayerNorm(d_msa)
        self.compute_bias = nn.Sequential(
            nn.LayerNorm(d_pair), nn.Linear(d_pair, n_heads, bias=False)
        )
        inner = n_heads * head_width
        self.Wv = nn.Linear(d_msa, inner, bias=False)
        self.Wgate = nn.Linear(d_msa, inner, bias=False)
        self.Wout = nn.Linear(inner, d_msa, bias=False)

    def forward(
        self, msa_repr: Tensor, pair_repr: Tensor, pair_attention_mask: Tensor
    ) -> Tensor:
        """
        Args:
            msa_repr: [B, L, M, d_msa]
            pair_repr: [B, L, L, d_pair]
            pair_attention_mask: [B, L, L]
        Returns:
            [B, L, M, d_msa]
        """
        batch, n_tok, n_seq, _ = msa_repr.shape
        heads, width = self.n_heads, self.head_width
        normed = self.norm_single(msa_repr)
        blocked = ~pair_attention_mask.unsqueeze(-1).bool()
        logits = self.compute_bias(pair_repr).masked_fill(blocked, -1e5)
        weights = logits.softmax(dim=-2)  # normalise over j
        head_shape = (batch, n_tok, n_seq, heads, width)
        values = self.Wv(normed).reshape(head_shape)
        gates = torch.sigmoid(self.Wgate(normed)).reshape(head_shape)
        mixed = torch.einsum("bijh,bjmhd,bimhd->bimhd", weights, values, gates)
        return self.Wout(mixed.reshape(batch, n_tok, n_seq, heads * width))