# ESMFold2-Fast / modeling_esmfold2.py
# Uploaded by lhallee via huggingface_hub ("Upload folder using huggingface_hub"), commit 407875b, 67.5 kB.
"""PyTorch ESMFold2 model — the standard released architecture.
Quickstart::
from transformers import ESMFold2Model
model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval()
open("ubq.pdb", "w").write(model.infer_protein_as_pdb("MQIFVKTLTGKT..."))
For multi-chain, ligand, and MSA inputs, use ``model.input_types`` together
with ``model.fold(...)`` or ``model.prepare_structure_input(...)``.
"""
import importlib
import math
from contextlib import contextmanager
from pathlib import Path
from typing import Any, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
try:
te = importlib.import_module("transformer_engine.pytorch")
te_recipe = importlib.import_module("transformer_engine.common.recipe")
DelayedScaling = te_recipe.DelayedScaling
Format = te_recipe.Format
TE_AVAILABLE = True
except ImportError:
te = None # type: ignore[assignment]
DelayedScaling = None # type: ignore[assignment]
Format = None # type: ignore[assignment]
TE_AVAILABLE = False
from transformers.modeling_utils import PreTrainedModel
try:
from fastplms.test_time_training import FastPLMTestTimeTrainingMixin, TTTConfig
except ImportError:
from .test_time_training import FastPLMTestTimeTrainingMixin, TTTConfig
from .configuration_esmfold2 import ESMFold2Config, normalize_esmc_id
from .modeling_esmfold2_common import (
CHAR_VOCAB_SIZE,
MAX_ATOMIC_NUMBER,
NUM_RES_TYPES,
DiffusionStructureHead,
FoldingTrunk,
InputsEmbedder,
LanguageModelShim,
MSAPairWeightedAveraging,
OuterProductMean,
ResIdxAsymIdSymIdEntityIdEncoding,
RowAttentionPooling,
SwiGLUMLP,
TriangleMultiplicativeUpdate,
_categorical_mean,
_compute_intra_token_idx,
compute_lm_hidden_states,
gather_rep_atom_coords,
gather_token_to_atom,
maybe_apply_msa_column_masking,
maybe_subsample_msa,
)
from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
from .esmfold2_conformers import load_ccd as _fastplms_esmfold2_load_ccd
from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL as _FASTPLMS_ESMFOLD2_ELEMENT_NUMBER_TO_SYMBOL
from .esmfold2_constants_esm3 import (
CHAIN_BREAK_STR as _FASTPLMS_ESMFOLD2_CHAIN_BREAK_STR,
SEQUENCE_BOS_TOKEN,
SEQUENCE_EOS_TOKEN,
SEQUENCE_MASK_TOKEN,
SEQUENCE_PAD_TOKEN,
SEQUENCE_STANDARD_AA_MAX_TOKEN,
SEQUENCE_STANDARD_AA_MIN_TOKEN,
SEQUENCE_VOCAB,
)
from .esmfold2_input_builder import StructurePredictionInput as _FastPLMSESMFold2StructurePredictionInput
from .esmfold2_metrics import compute_rmsd as _fastplms_esmfold2_compute_rmsd
from .esmfold2_misc import slice_any_object as _fastplms_esmfold2_slice_any_object
from .esmfold2_mmcif_parsing import MmcifWrapper as _FastPLMSESMFold2MmcifWrapper
from .esmfold2_molecular_complex import MolecularComplex as _FastPLMSESMFold2MolecularComplex
from .esmfold2_msa import MSA as _FastPLMSESMFold2MSA
from .esmfold2_msa_filter_sequences import greedy_select_indices as _fastplms_esmfold2_greedy_select_indices
from .esmfold2_normalize_coordinates import normalize_coordinates as _fastplms_esmfold2_normalize_coordinates
from .esmfold2_output import build_molecular_complex_from_features as _fastplms_esmfold2_build_molecular_complex_from_features
from .esmfold2_paired_msa import construct_paired_msa as _fastplms_esmfold2_construct_paired_msa
from .esmfold2_parsing import FastaEntry as _FastPLMSESMFold2FastaEntry
from .esmfold2_predicted_aligned_error import compute_tm as _fastplms_esmfold2_compute_tm
from .esmfold2_prepare_input import prepare_esmfold2_input as _fastplms_esmfold2_prepare_esmfold2_input
from .esmfold2_processor import ESMFold2InputBuilder as _FastPLMSESMFold2InputBuilder
from .esmfold2_protein_chain import ProteinChain as _FastPLMSESMFold2ProteinChain
from .esmfold2_protein_complex import ProteinComplex as _FastPLMSESMFold2ProteinComplex
from .esmfold2_protein_structure import index_by_atom_name as _fastplms_esmfold2_index_by_atom_name
from .esmfold2_residue_constants import restypes as _FASTPLMS_ESMFOLD2_RESTYPES
from .esmfold2_sequential_dataclass import SequentialDataclass as _FastPLMSESMFold2SequentialDataclass
from .esmfold2_system import run_subprocess_with_errorcheck as _fastplms_esmfold2_run_subprocess_with_errorcheck
from .esmfold2_types import ProteinInput as _FastPLMSESMFold2ProteinInput
from .esmfold2_utils_types import PathOrBuffer as _FastPLMSESMFold2PathOrBuffer
# Small constant added to denominators when averaging over masks that may be empty.
_EPS = 1e-6
# ``mol_type`` value that ConfidenceHead treats as a ligand (non-polymer) token.
_NONPOLYMER_ID = 4
# Default for the triangle / OPM / pair-transition L² ops. Caps peak memory
# so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438;
# chunk=64 leaves headroom for the largest foldbench targets). Override via
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
# short L but OOM-prone past ~600). Chunks are taken along the token axis.
_DEFAULT_CHUNK_SIZE = 64
class _ESMFold2ESMplusplusAdapter(nn.Module):
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
@property
def config(self):
return self.model.config
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor | None = None,
sequence_id: Tensor | None = None,
output_hidden_states: bool | None = None,
output_attentions: bool | None = None,
return_dict: bool | None = None,
compute_sae: bool = True,
normalize_sae: bool = False,
):
del return_dict, compute_sae, normalize_sae
output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
sequence_id=sequence_id,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
esmfold2_hidden_states=True,
)
if output_hidden_states:
hidden_states = output.hidden_states
assert hidden_states is not None, "ESM++ did not return hidden states."
if isinstance(hidden_states, torch.Tensor):
output.hidden_states = hidden_states
else:
output.hidden_states = torch.stack(tuple(hidden_states), dim=0)
return output
def _load_fastplms_esmplusplus_for_esmfold2(
    esmc_model_path: str,
    attn_backend: str,
    device: torch.device,
    dtype: torch.dtype,
) -> _ESMFold2ESMplusplusAdapter:
    """Load an ESM++ backbone and return it wrapped, placed and in eval mode.

    The checkpoint id is first passed through ``normalize_esmc_id``; the ESM++
    config is loaded from it and its ``attn_backend`` overridden before the
    weights are loaded. The wrapped model is moved to ``device``/``dtype``.
    """
    # Prefer the installed fastplms package; fall back to the vendored copy.
    try:
        from fastplms.esm_plusplus.modeling_esm_plusplus import (
            ESMplusplusConfig,
            ESMplusplusModel,
        )
    except ImportError:
        from .modeling_esm_plusplus import ESMplusplusConfig, ESMplusplusModel
    checkpoint_id = normalize_esmc_id(esmc_model_path)
    backbone_config = ESMplusplusConfig.from_pretrained(checkpoint_id)
    backbone_config.attn_backend = attn_backend
    backbone = ESMplusplusModel.from_pretrained(checkpoint_id, config=backbone_config)
    wrapped = _ESMFold2ESMplusplusAdapter(backbone)
    return wrapped.to(device=device, dtype=dtype).eval()
class PairTransition(nn.Module):
"""LayerNorm + SwiGLU feed-forward residual block on the pair representation."""
def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
super().__init__()
self.norm = nn.LayerNorm(d_model)
self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)
self._chunk_size: int | None = _DEFAULT_CHUNK_SIZE
def set_chunk_size(self, chunk_size: int | None) -> None:
self._chunk_size = chunk_size
def forward(self, x: Tensor) -> Tensor:
if self._chunk_size is None or x.shape[1] <= self._chunk_size:
return self.ffn(self.norm(x))
out: list[Tensor] = []
for s in range(0, x.shape[1], self._chunk_size):
e = min(s + self._chunk_size, x.shape[1])
sl = x[:, s:e]
out.append(self.ffn(self.norm(sl)))
return torch.cat(out, dim=1)
class ConfidenceHead(nn.Module):
"""Predicts pLDDT, PAE, PDE, resolved-atom probability and distogram bins."""
boundaries: Tensor
def __init__(self, config: "ESMFold2Config") -> None:
super().__init__()
ch = config.confidence_head
d_single = config.d_single
d_pair = config.d_pair
d_inputs = config.inputs.d_inputs
boundaries = torch.linspace(ch.min_dist, ch.max_dist, ch.distogram_bins - 1)
self.register_buffer("boundaries", boundaries)
self.dist_bin_pairwise_embed = nn.Embedding(ch.distogram_bins, d_pair)
self.s_norm = nn.LayerNorm(d_single)
self.s_inputs_to_single = nn.Linear(d_inputs, d_single, bias=False)
self.s_to_z = nn.Linear(d_inputs, d_pair, bias=False)
self.s_to_z_transpose = nn.Linear(d_inputs, d_pair, bias=False)
self.s_to_z_prod_in1 = nn.Linear(d_inputs, d_pair, bias=False)
self.s_to_z_prod_in2 = nn.Linear(d_inputs, d_pair, bias=False)
self.s_to_z_prod_out = nn.Linear(d_pair, d_pair, bias=False)
self.s_input_to_s = nn.Linear(d_inputs, d_single, bias=False)
self.s_inputs_norm = nn.LayerNorm(d_inputs)
self.z_norm = nn.LayerNorm(d_pair)
self.row_attention_pooling = RowAttentionPooling(
d_pair=d_pair, d_single=d_single
)
pf = ch.folding_trunk
self.folding_trunk = FoldingTrunk(
n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
)
# Heads.
self.plddt_ln = nn.LayerNorm(d_single)
max_atoms_per_token = 23
self.plddt_weight = nn.Parameter(
torch.zeros(max_atoms_per_token, d_single, ch.num_plddt_bins)
)
self.pae_ln = nn.LayerNorm(d_pair)
self.pae_head = nn.Linear(d_pair, ch.num_pae_bins, bias=False)
self.pde_ln = nn.LayerNorm(d_pair)
self.pde_head = nn.Linear(d_pair, ch.num_pde_bins, bias=False)
self.resolved_ln = nn.LayerNorm(d_single)
# 2 = resolved logits ([unresolved, resolved]).
self.resolved_weight = nn.Parameter(
torch.zeros(max_atoms_per_token, d_single, 2)
)
def set_kernel_backend(self, backend: str | None) -> None:
self.folding_trunk.set_kernel_backend(backend)
def set_chunk_size(self, chunk_size: int | None) -> None:
self.folding_trunk.set_chunk_size(chunk_size)
@staticmethod
def _repeat_batch(x: Tensor, num_diffusion_samples: int) -> Tensor:
return (
x
if num_diffusion_samples == 1
else x.repeat_interleave(num_diffusion_samples, 0)
)
@staticmethod
def _flatten_sample_axis(x: Tensor) -> Tensor:
if x.ndim == 4:
b, mult, n, c = x.shape
return x.reshape(b * mult, n, c)
return x
def forward(
self,
s_inputs: Tensor,
z: Tensor,
x_pred: Tensor,
distogram_atom_idx: Tensor,
token_attention_mask: Tensor,
atom_to_token: Tensor,
atom_attention_mask: Tensor,
asym_id: Tensor,
mol_type: Tensor,
num_diffusion_samples: int = 1,
relative_position_encoding: Tensor | None = None,
token_bonds_encoding: Tensor | None = None,
) -> dict[str, Tensor]:
s_inputs_normed = self.s_inputs_norm(s_inputs)
z_base = self.z_norm(z)
if relative_position_encoding is not None:
z_base = z_base + relative_position_encoding
if token_bonds_encoding is not None:
z_base = z_base + token_bonds_encoding
z_base = z_base + self.s_to_z(s_inputs_normed).unsqueeze(2)
z_base = z_base + self.s_to_z_transpose(s_inputs_normed).unsqueeze(1)
z_base = z_base + self.s_to_z_prod_out(
self.s_to_z_prod_in1(s_inputs_normed)[:, :, None, :]
* self.s_to_z_prod_in2(s_inputs_normed)[:, None, :, :]
)
pair = self._repeat_batch(z_base, num_diffusion_samples)
x_pred_flat = self._flatten_sample_axis(x_pred)
atom_to_token_m = self._repeat_batch(atom_to_token, num_diffusion_samples)
atom_mask_m = self._repeat_batch(atom_attention_mask, num_diffusion_samples)
rep_idx_m = self._repeat_batch(distogram_atom_idx, num_diffusion_samples).long()
mask = self._repeat_batch(token_attention_mask, num_diffusion_samples)
Bm = pair.shape[0]
rep_coords = gather_rep_atom_coords(x_pred_flat, rep_idx_m)
rep_distances = torch.cdist(
rep_coords, rep_coords, compute_mode="donot_use_mm_for_euclid_dist"
)
distogram_bins = (
(rep_distances.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
)
pair = pair + self.dist_bin_pairwise_embed(distogram_bins)
pair_mask = mask[:, :, None].float() * mask[:, None, :].float()
# FoldingTrunk handles the bf16 cast internally during inference so
# each block's fused trimul engages. In-place residual avoids an
# extra fp32 pair allocation.
with torch.amp.autocast("cuda", enabled=pair.is_cuda, dtype=torch.bfloat16):
pair_delta = self.folding_trunk(pair, pair_attention_mask=pair_mask)
pair.add_(pair_delta.float())
del pair_delta
single = self.row_attention_pooling(pair, mask)
atom_mask_f = atom_mask_m.float()
s_at_atoms = gather_token_to_atom(single, atom_to_token_m)
s_at_atoms_ln = self.plddt_ln(s_at_atoms)
intra_idx = _compute_intra_token_idx(atom_to_token_m)
intra_idx = intra_idx.clamp(max=self.plddt_weight.shape[0] - 1)
w_plddt = self.plddt_weight[intra_idx]
plddt_logits = torch.einsum("...c,...cb->...b", s_at_atoms_ln, w_plddt)
plddt_per_atom = _categorical_mean(plddt_logits, start=0.0, end=1.0)
L = single.shape[1]
plddt_sum = torch.zeros(Bm, L, device=single.device, dtype=plddt_per_atom.dtype)
atom_count = torch.zeros(
Bm, L, device=single.device, dtype=plddt_per_atom.dtype
)
atom_mask_t = atom_mask_f.to(plddt_per_atom.dtype)
plddt_sum.scatter_add_(1, atom_to_token_m, plddt_per_atom * atom_mask_t)
atom_count.scatter_add_(1, atom_to_token_m, atom_mask_t)
plddt = plddt_sum / atom_count.clamp(min=1e-6)
complex_plddt = (plddt_per_atom * atom_mask_f).sum(dim=-1) / (
atom_mask_f.sum(dim=-1) + _EPS
)
expanded_type = self._repeat_batch(mol_type, num_diffusion_samples)
expanded_asym = self._repeat_batch(asym_id, num_diffusion_samples)
is_ligand = (expanded_type == _NONPOLYMER_ID).float()
inter_chain = (
expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
).float()
near_contact = (rep_distances < 8).float()
interface_per_token = (
near_contact * inter_chain * (1.0 - is_ligand).unsqueeze(-1)
).amax(dim=-1)
iplddt_weight = torch.where(
is_ligand.bool(),
torch.full_like(interface_per_token, 2.0),
interface_per_token,
)
iplddt_weight_atoms = gather_token_to_atom(
iplddt_weight.unsqueeze(-1), atom_to_token_m
).squeeze(-1)
atom_iplddt_w = atom_mask_f * iplddt_weight_atoms
complex_iplddt = (plddt_per_atom * atom_iplddt_w).sum(dim=-1) / (
atom_iplddt_w.sum(dim=-1) + _EPS
)
plddt_ca = plddt_per_atom.gather(1, rep_idx_m)
# PAE
pae_logits = self.pae_head(self.pae_ln(pair))
pae = _categorical_mean(pae_logits, start=0.0, end=32.0).detach()
# PDE
pde_logits = self.pde_head(self.pde_ln(pair))
pde = _categorical_mean(pde_logits, start=0.0, end=32.0).detach()
# Resolved (per-atom binary).
s_at_atoms_res = self.resolved_ln(s_at_atoms)
w_res = self.resolved_weight[intra_idx]
resolved_logits = torch.einsum("...c,...cb->...b", s_at_atoms_res, w_res)
# pTM / ipTM from pae_logits.
n_bins = pae_logits.shape[-1]
bin_width = 32.0 / n_bins
bin_centers = torch.arange(
0.5 * bin_width, 32.0, bin_width, device=pae_logits.device
)
mask_f = mask.float()
N_res = mask_f.sum(dim=-1, keepdim=True)
d0 = 1.24 * (N_res.clamp(min=19) - 15) ** (1 / 3) - 1.8
tm_per_bin = 1 / (1 + (bin_centers / d0) ** 2)
pae_probs = F.softmax(pae_logits, dim=-1)
tm_expected = (pae_probs * tm_per_bin[:, None, None, :]).sum(dim=-1)
pair_mask_2d = mask_f.unsqueeze(-1) * mask_f.unsqueeze(-2)
ptm_per_row = (tm_expected * pair_mask_2d).sum(dim=-1) / (
pair_mask_2d.sum(dim=-1) + _EPS
)
ptm = ptm_per_row.max(dim=-1).values
inter_chain_mask = (
expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
).float() * pair_mask_2d
iptm_per_row = (tm_expected * inter_chain_mask).sum(dim=-1) / (
inter_chain_mask.sum(dim=-1) + _EPS
)
iptm = iptm_per_row.max(dim=-1).values
max_chain_id = int(expanded_asym.max().item()) if Bm > 0 else 0
n_chains = max_chain_id + 1
pair_chains_iptm = torch.zeros(
Bm, n_chains, n_chains, device=tm_expected.device, dtype=tm_expected.dtype
)
for c1 in range(n_chains):
chain_c1 = (expanded_asym == c1).float() * mask_f
if chain_c1.sum() == 0:
continue
for c2 in range(n_chains):
chain_c2 = (expanded_asym == c2).float() * mask_f
pair_m = chain_c1.unsqueeze(-1) * chain_c2.unsqueeze(-2)
denom = pair_m.sum(dim=(-1, -2)) + _EPS
pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(
dim=(-1, -2)
) / denom
return {
"plddt_logits": plddt_logits,
"plddt": plddt.detach(),
"plddt_per_atom": plddt_per_atom.detach(),
"plddt_ca": plddt_ca.detach(),
"complex_plddt": complex_plddt.detach(),
"complex_iplddt": complex_iplddt.detach(),
"pae_logits": pae_logits,
"pae": pae,
"pde_logits": pde_logits,
"pde": pde,
"resolved_logits": resolved_logits,
"ptm": ptm.detach(),
"iptm": iptm.detach(),
"pair_chains_iptm": pair_chains_iptm.detach(),
}
def _inverse_softplus(value: float) -> float:
    """Return ``x`` such that ``softplus(x) == value`` (defined for ``value > 0``).

    Uses ``log(exp(v) - 1) == v + log(1 - exp(-v))``. ``expm1`` keeps the
    result accurate for small ``v``. Non-positive input raises ``ValueError``
    from ``math.log``.
    """
    one_minus_exp_neg = -math.expm1(-value)
    return value + math.log(one_minus_exp_neg)
def _convert_te_modules_to_fp8_inplace(module: nn.Module) -> None:
    """Re-init each TE module via quantized_model_init so weights live as fp8.

    Must be called inside torch.no_grad(); covers nn.Linear, te.Linear,
    te.LayerNormLinear, te.LayerNormMLP — the last two hold 99% of ESMC weight.

    Walks ``module`` recursively. Each matching child is first replaced on its
    parent by ``nn.Identity`` and then by a rebuilt module:

    * ``nn.Linear`` becomes a new ``te.Linear``. Its weight is copied with
      ``quantize_`` and its bias is copied directly.
    * ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` become a
      new instance of the same TE class, built under ``quantized_model_init``
      and filled from a cloned ``state_dict``.

    Rebuilt modules are put in eval mode with gradients disabled. Children that
    match none of these types are searched recursively.

    Raises:
        RuntimeError: if transformer_engine could not be imported.
    """
    if not TE_AVAILABLE:
        raise RuntimeError("transformer_engine is not available; cannot use fp8.")
    quantized_model_init = importlib.import_module(
        "transformer_engine.pytorch"
    ).quantized_model_init

    def _walk(mod: nn.Module) -> None:
        # Snapshot the children: the loop body swaps them out with setattr.
        for name, child in list(mod.named_children()):
            replaced = False
            if isinstance(child, nn.Linear):
                in_f, out_f = child.in_features, child.out_features
                has_bias = child.bias is not None
                device = child.weight.device
                dtype = child.weight.dtype
                # Hold the source tensors before the child module is dropped.
                w = child.weight.data
                b = child.bias.data if has_bias else None
                # Detach the old module before allocating its replacement;
                # empty_cache presumably aims to keep peak GPU memory down.
                setattr(mod, name, nn.Identity())
                del child
                torch.cuda.empty_cache()
                with quantized_model_init(enabled=True):
                    new_mod = te.Linear(  # type: ignore[union-attr]
                        in_f, out_f, bias=has_bias, params_dtype=dtype
                    ).to(device)
                # Write the original weight into the quantized parameter.
                new_mod.weight.quantize_(w)  # type: ignore[attr-defined,operator]
                if has_bias:
                    assert b is not None
                    new_mod.bias.data.copy_(b)  # type: ignore[union-attr]
                del w, b
                replaced = True
            elif isinstance(child, te.Linear):  # type: ignore[union-attr]
                # te.Linear with bf16 weight → re-init inside quantized_model_init for fp8.
                in_f, out_f = child.in_features, child.out_features
                has_bias = child.bias is not None
                device = child.weight.device
                # A weight exposing ``_data`` is presumably already a TE
                # quantized tensor; bf16 params are used in that case.
                dtype = (
                    child.weight.dtype
                    if not hasattr(child.weight, "_data")
                    else torch.bfloat16
                )
                # Clone the state so it survives deleting ``child``.
                state = {k: v.detach().clone() for k, v in child.state_dict().items()}
                setattr(mod, name, nn.Identity())
                del child
                torch.cuda.empty_cache()
                with quantized_model_init(enabled=True):
                    new_mod = te.Linear(  # type: ignore[union-attr]
                        in_f,
                        out_f,
                        bias=has_bias,
                        params_dtype=dtype,  # type: ignore[arg-type]
                    ).to(device)  # type: ignore[arg-type]
                # NOTE(review): strict=False silently ignores missing/unexpected keys.
                new_mod.load_state_dict(state, strict=False)
                replaced = True
            elif (
                hasattr(te, "LayerNormLinear") and isinstance(child, te.LayerNormLinear)  # type: ignore[union-attr]
            ):
                state = {k: v.detach().clone() for k, v in child.state_dict().items()}
                hidden_size = child.in_features
                out_features = child.out_features
                has_bias = child.use_bias
                device = next(child.parameters()).device
                setattr(mod, name, nn.Identity())
                del child
                torch.cuda.empty_cache()
                # Rebuilt with bf16 params whatever the source dtype was.
                # NOTE(review): only sizes and bias are carried over; other
                # constructor options (e.g. normalization type / eps) take TE
                # defaults — confirm they match the source module.
                with quantized_model_init(enabled=True):
                    new_mod = te.LayerNormLinear(  # type: ignore[union-attr]
                        hidden_size,
                        out_features,
                        bias=has_bias,
                        params_dtype=torch.bfloat16,
                    ).to(device)
                new_mod.load_state_dict(state, strict=False)
                replaced = True
            elif (
                hasattr(te, "LayerNormMLP") and isinstance(child, te.LayerNormMLP)  # type: ignore[union-attr]
            ):
                state = {k: v.detach().clone() for k, v in child.state_dict().items()}
                fc1_weight: Tensor = child.fc1_weight  # type: ignore[attr-defined]
                hidden_size = int(fc1_weight.shape[1])
                # fc1 packed as (2*ffn_hidden_size, hidden_size) for swiglu.
                ffn_hidden_size = int(fc1_weight.shape[0]) // 2
                has_bias = (
                    getattr(child, "fc1_bias", None) is not None
                    and child.fc1_bias is not None  # type: ignore[attr-defined]
                )
                device = fc1_weight.device
                setattr(mod, name, nn.Identity())
                del child
                torch.cuda.empty_cache()
                # NOTE(review): the activation is hard-coded to swiglu and other
                # constructor options take TE defaults — assumes every
                # LayerNormMLP in the model was built that way.
                with quantized_model_init(enabled=True):
                    new_mod = te.LayerNormMLP(  # type: ignore[union-attr]
                        hidden_size=hidden_size,
                        ffn_hidden_size=ffn_hidden_size,
                        bias=has_bias,
                        activation="swiglu",
                        params_dtype=torch.bfloat16,
                    ).to(device)  # type: ignore[arg-type]
                new_mod.load_state_dict(state, strict=False)
                replaced = True
            if replaced:
                # Freeze via .eval()+.requires_grad_(False); per-param ops would unwrap Float8Tensor.
                new_mod.eval().requires_grad_(False)
                setattr(mod, name, new_mod)
                torch.cuda.empty_cache()
            else:
                # Not a supported linear type: look for targets inside it.
                _walk(child)

    _walk(module)
    torch.cuda.empty_cache()
@contextmanager
def _lm_precision_context(fp8: bool):
"""bf16 autocast (+ optional TE fp8 autocast) around the LM forward.
te.autocast keeps te.Linear outputs bf16 instead of the fp32 default
(~425 MB at L=1024 in the hidden-state cache).
"""
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
if fp8 and TE_AVAILABLE:
fp8_recipe = DelayedScaling( # type: ignore[misc]
fp8_format=Format.HYBRID, # type: ignore[union-attr]
amax_history_len=1,
amax_compute_algo="most_recent",
)
with te.autocast(enabled=True, recipe=fp8_recipe): # type: ignore[union-attr]
yield
else:
yield
class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel):
"""ESMFold2 — all-atom structure prediction with an ESMC PLM backbone.
This is the standard released ESMFold2 architecture (uses a linear-
recurrent trunk, internally referred to as "parcae").
Forward kwargs that callers commonly override:
* ``num_loops`` (default ``config.num_loops``): trunk refinement
loops.
* ``num_diffusion_samples`` (default ``config.num_diffusion_samples``):
parallel structure samples; the confidence head re-runs once per
sample, so memory scales linearly. Pass ``1`` for cheap inference.
* ``num_sampling_steps`` (default ``config.structure_head.inference_num_steps``):
diffusion ODE solver steps. Lower for speed, higher for quality.
Memory / perf knobs:
* ``model.set_chunk_size(int|None)``: caps L² ops (triangle / OPM /
pair transition) at this token-axis chunk. Default 64 — fits
L≈2k on an 80 GB GPU. Pass ``None`` for faster inference at L<600.
* ``model.set_kernel_backend(None | "fused" | "cuequivariance")``:
select kernel backend (None = reference path).
"""
config_class = ESMFold2Config
_keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
def __init__(self, config: ESMFold2Config) -> None:
    """Build every ESMFold2 submodule described by ``config``.

    The ESMC language model is not created here: ``self._esmc`` starts as
    None and is attached later by ``load_esmc``. The LM-pair encoder and the
    MSA encoder are created only when their config sections are enabled.
    """
    super().__init__(config)
    input_dim = config.inputs.d_inputs
    pair_dim = config.d_pair

    # Input featurisation and pair initialisation.
    self.inputs_embedder = InputsEmbedder(config)
    self.z_init_1 = nn.Linear(input_dim, pair_dim, bias=False)
    self.z_init_2 = nn.Linear(input_dim, pair_dim, bias=False)
    self.rel_pos = ResIdxAsymIdSymIdEntityIdEncoding(
        n_relative_residx_bins=config.n_relative_residx_bins,
        n_relative_chain_bins=config.n_relative_chain_bins,
        d_pair=pair_dim,
    )
    self.token_bonds = nn.Linear(1, pair_dim, bias=False)
    self.language_model = LanguageModelShim(
        d_z=pair_dim, d_model=config.lm_d_model, num_layers=config.lm_num_layers
    )

    # LM backbone / TTT / input-builder state, filled in lazily.
    self._esmc: nn.Module | None = None
    self._esmc_fp8: bool = False
    self._ttt_lm_head: nn.Module | None = None
    self._esmfold2_input_builder: Any | None = None

    # Pair trunk and the optional encoder for LM-derived pair features.
    self.folding_trunk = FoldingTrunk(
        n_layers=config.folding_trunk.n_layers, d_pair=pair_dim, expansion_ratio=4
    )
    lm_enc_cfg = config.lm_encoder
    self.lm_encoder: FoldingTrunk | None = (
        FoldingTrunk(n_layers=lm_enc_cfg.n_layers, d_pair=pair_dim, expansion_ratio=4)
        if lm_enc_cfg.enabled
        else None
    )

    # Parcae linear-recurrence parameters.
    self.parcae_input_norm = nn.LayerNorm(pair_dim)
    self.parcae_log_a = nn.Parameter(torch.zeros(pair_dim))
    initial_decay = math.sqrt(1.0 / 5.0)
    initial_delta = -math.log(initial_decay)
    self.parcae_log_delta = nn.Parameter(
        torch.full(
            (pair_dim,), _inverse_softplus(initial_delta), dtype=torch.float32
        )
    )
    self.parcae_b_cont = nn.Parameter(torch.eye(pair_dim))
    self.parcae_readout = nn.Linear(pair_dim, pair_dim, bias=False)
    nn.init.eye_(self.parcae_readout.weight)
    self.parcae_coda = FoldingTrunk(
        n_layers=config.parcae.coda_n_layers, d_pair=pair_dim, expansion_ratio=4
    )

    # Heads ----------------------------------------------------------------
    self.structure_head = DiffusionStructureHead(config)
    self.distogram_head = nn.Linear(
        pair_dim, config.structure_head.distogram_bins, bias=True
    )
    self.confidence_head = ConfidenceHead(config)
    msa_settings = config.msa_encoder
    self.msa_encoder = (
        MSAEncoder(
            d_msa=msa_settings.d_msa,
            d_pair=pair_dim,
            d_inputs=input_dim,
            d_hidden=msa_settings.d_hidden,
            n_layers=msa_settings.n_layers,
            n_heads_msa=msa_settings.n_heads_msa,
            msa_head_width=msa_settings.msa_head_width,
        )
        if msa_settings.enabled
        else None
    )
    self.post_init()
    self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
    """Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone.

    Args:
        esmc_model_path: ESMC checkpoint id or path (normalised by the loader).
        precision: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``. For
            fp8, weights are loaded in bf16 and then converted with
            transformer_engine.

    Raises:
        ValueError: if ``precision`` is unknown, or the loaded ESM++ hidden
            size / layer count disagrees with ``config.lm_d_model`` /
            ``config.lm_num_layers``.
        RuntimeError: if ``precision="fp8"`` and transformer_engine is
            unavailable.
    """
    dtype_map = {
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        # fp8: load bf16 weights, then quantise them with TE below.
        "fp8": torch.bfloat16,
    }
    if precision not in dtype_map:
        raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}")
    if precision == "fp8" and not TE_AVAILABLE:
        raise RuntimeError(
            "esmc_precision='fp8' requires transformer_engine.pytorch."
        )
    esmc = _load_fastplms_esmplusplus_for_esmfold2(
        esmc_model_path=esmc_model_path,
        attn_backend=self.config.esmc_attn_backend,
        device=self.device,
        dtype=dtype_map[precision],
    )
    # Explicit checks rather than asserts, so a mismatched backbone is still
    # rejected under ``python -O``.
    if esmc.config.hidden_size != self.config.lm_d_model:
        raise ValueError(
            f"ESMFold2 expected lm_d_model={self.config.lm_d_model}, "
            f"but loaded ESM++ hidden_size={esmc.config.hidden_size}."
        )
    if esmc.config.num_hidden_layers != self.config.lm_num_layers:
        raise ValueError(
            f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, "
            f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}."
        )
    # Freeze every backbone parameter.
    esmc.requires_grad_(False)
    if precision == "fp8":
        with torch.no_grad():
            _convert_te_modules_to_fp8_inplace(esmc)
    self._esmc_fp8 = precision == "fp8"
    self._esmc = esmc
    # Drop any TTT head that was built for a previously loaded backbone.
    self._ttt_lm_head = None
def _ensure_ttt_lm_head(self) -> None:
    """Lazily attach a frozen, pretrained ESM++ MLM head for TTT logits.

    The head is taken from an ``ESMplusplusForMaskedLM`` loaded for
    ``config.esmc_id``, which is normalised with ``normalize_esmc_id`` exactly
    as the backbone loader does. It is cast to the backbone's dtype and the
    model's device; the rest of the MLM is discarded. Does nothing if a head
    is already attached.

    Raises:
        AssertionError: if no backbone is loaded, or the checkpoint has no
            pretrained ``sequence_head`` weights.
        RuntimeError: if the backbone was loaded in fp8.
    """
    assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
    if self._esmc_fp8:
        raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
    if self._ttt_lm_head is not None:
        return
    try:
        from fastplms.esm_plusplus.modeling_esm_plusplus import (
            ESMplusplusConfig,
            ESMplusplusForMaskedLM,
        )
    except ImportError:
        from .modeling_esm_plusplus import (
            ESMplusplusConfig,
            ESMplusplusForMaskedLM,
        )
    # Resolve the id the same way the backbone loader does, so the head and
    # the backbone come from the same checkpoint.
    esmc_id = normalize_esmc_id(self.config.esmc_id)
    esmc_config = ESMplusplusConfig.from_pretrained(esmc_id)
    esmc_config.attn_backend = self.config.esmc_attn_backend
    mlm, loading_info = ESMplusplusForMaskedLM.from_pretrained(
        esmc_id,
        config=esmc_config,
        output_loading_info=True,
    )
    # A randomly initialised head would make TTT meaningless, so refuse it.
    missing_head_keys = [
        key
        for key in loading_info["missing_keys"]
        if key.startswith("sequence_head")
    ]
    assert len(missing_head_keys) == 0, (
        f"ESMFold2 TTT could not load a pretrained ESM++ MLM head from "
        f"{self.config.esmc_id}: missing {missing_head_keys}"
    )
    dtype = next(self._esmc.parameters()).dtype
    mlm = mlm.to(device=self.device, dtype=dtype).eval()
    self._ttt_lm_head = mlm.sequence_head
    self._ttt_lm_head.requires_grad_(False)
    del mlm
def _ttt_get_trainable_modules(self) -> list[nn.Module]:
    """Modules TTT may update: only the ESM++ backbone (never when it is fp8)."""
    backbone = self._esmc
    assert backbone is not None, "ESMFold2 TTT requires load_esmc=True."
    if not self._esmc_fp8:
        return [backbone]
    raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
def _ttt_tokenize(
    self,
    seq: str | list[str] | None = None,
    input_ids: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Encode sequence(s) into a right-padded ``(batch, max_len)`` long tensor.

    If ``input_ids`` is given, it is returned unchanged. Otherwise each
    sequence becomes ``[BOS, *residues, EOS]``:

    * characters not in ``SEQUENCE_VOCAB`` map to the ``"X"`` token;
    * shorter rows are padded with ``SEQUENCE_PAD_TOKEN``.
    """
    del kwargs
    if input_ids is not None:
        return input_ids
    assert seq is not None, "Pass either seq or input_ids for ESMFold2 TTT."
    batch_sequences = [seq] if isinstance(seq, str) else seq
    vocab_index = {token: idx for idx, token in enumerate(SEQUENCE_VOCAB)}

    def _encode(sequence: str) -> list[int]:
        residues = [
            vocab_index[ch] if ch in vocab_index else vocab_index["X"]
            for ch in sequence
        ]
        return [SEQUENCE_BOS_TOKEN, *residues, SEQUENCE_EOS_TOKEN]

    rows = [_encode(sequence) for sequence in batch_sequences]
    width = max(len(row) for row in rows)
    padded_rows = [row + [SEQUENCE_PAD_TOKEN] * (width - len(row)) for row in rows]
    return torch.tensor(padded_rows, dtype=torch.long)
def _ttt_mask_token(self) -> int:
    """Token id used for masked positions: ``SEQUENCE_MASK_TOKEN``."""
    mask_id = SEQUENCE_MASK_TOKEN
    return mask_id
def _ttt_padding_token(self) -> int:
    """Token id used for padding: ``SEQUENCE_PAD_TOKEN``."""
    pad_id = SEQUENCE_PAD_TOKEN
    return pad_id
def _ttt_replacement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Token ids in ``[SEQUENCE_STANDARD_AA_MIN_TOKEN, SEQUENCE_STANDARD_AA_MAX_TOKEN)``.

    The result is placed on ``input_ids``' device and uses its dtype.
    """
    low, high = SEQUENCE_STANDARD_AA_MIN_TOKEN, SEQUENCE_STANDARD_AA_MAX_TOKEN
    return torch.arange(low, high, dtype=input_ids.dtype, device=input_ids.device)
def _ttt_non_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Boolean mask, True where a token id is in the standard amino-acid range."""
    at_least_min = input_ids >= SEQUENCE_STANDARD_AA_MIN_TOKEN
    below_max = input_ids < SEQUENCE_STANDARD_AA_MAX_TOKEN
    return at_least_min & below_max
def _ttt_predict_logits(
    self,
    batch: torch.Tensor | dict[str, torch.Tensor],
    **kwargs,
) -> torch.Tensor:
    """MLM logits for ``batch`` token ids from the ESM++ backbone and TTT head.

    Positions equal to ``SEQUENCE_PAD_TOKEN`` are excluded through the
    attention mask. The head is attached on first use via
    ``_ensure_ttt_lm_head``.
    """
    del kwargs
    assert isinstance(batch, torch.Tensor), (
        "ESMFold2 TTT expects input_ids tensors."
    )
    backbone = self._esmc
    assert backbone is not None, "ESMFold2 TTT requires load_esmc=True."
    if self._esmc_fp8:
        raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
    self._ensure_ttt_lm_head()
    lm_head = self._ttt_lm_head
    assert lm_head is not None
    not_padding = batch.ne(SEQUENCE_PAD_TOKEN)
    encoded = backbone(
        input_ids=batch,
        attention_mask=not_padding,
        return_dict=True,
        compute_sae=False,
    )
    return lm_head(encoded.last_hidden_state)
@classmethod
def from_pretrained(
    cls, pretrained_model_name_or_path, *args, load_esmc: bool = True, **kwargs
):
    """Load ESMFold2 weights and, by default, attach the ESMC backbone.

    Extra keyword arguments:
        esmc_precision: forwarded to ``load_esmc`` as ``precision`` (default
            ``"bf16"``). It is consumed here and never passed to the HF config
            or model loaders.

    All other kwargs go to ``PreTrainedModel.from_pretrained``. When
    ``output_loading_info=True``, the HF ``(model, loading_info)`` tuple is
    returned, and the backbone is attached to the model inside it.

    Raises:
        ValueError: if the config describes an experimental ESMFold2 variant.
    """
    # Pop the precision knob first so neither the config loader nor the HF
    # model loader ever sees it.
    esmc_precision = kwargs.pop("esmc_precision", "bf16")
    if cls is ESMFold2Model and "config" not in kwargs:
        config = ESMFold2Config.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        if config.type == "experimental":
            raise ValueError(
                "FastPLMs ESMFold2 supports the released ESMFold2 and "
                "ESMFold2-Fast checkpoints. Experimental ESMFold2 configs "
                "are not part of the self-contained AutoModel package."
            )
        kwargs["config"] = config
    loaded = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
    # HF returns (model, loading_info) when output_loading_info=True.
    model = loaded[0] if kwargs.get("output_loading_info", False) else loaded
    if load_esmc:
        model.load_esmc(model.config.esmc_id, precision=esmc_precision)
    return loaded
def set_kernel_backend(self, backend: str | None) -> None:
"""Select kernel backend.
Args:
backend: ``None`` (reference path), ``"fused"`` (vendored Triton
kernels), or ``"cuequivariance"`` (cuequivariance kernels
where applicable; vanilla python fallback otherwise).
"""
self.folding_trunk.set_kernel_backend(backend)
if self.lm_encoder is not None:
self.lm_encoder.set_kernel_backend(backend)
self.parcae_coda.set_kernel_backend(backend)
self.confidence_head.set_kernel_backend(backend)
self.structure_head.set_kernel_backend(backend)
def apply_torch_compile(
self, mode: str = "fixed_seqlen", dynamic: bool | None = None
) -> None:
"""Compile L²-heavy blocks. ``mode='fixed_seqlen'`` recompiles per L; ``'dynamic_seqlen'`` compiles once.
Does NOT stack with our Triton kernels — call ``set_kernel_backend(None)``
before compiling.
"""
import torch._dynamo
torch._dynamo.config.cache_size_limit = 512 # type: ignore[attr-defined]
torch._dynamo.config.accumulated_cache_size_limit = 512 # type: ignore[attr-defined]
# capture_scalar_outputs avoids graph breaks at .item() in atom-attention path.
torch._dynamo.config.capture_scalar_outputs = True # type: ignore[attr-defined]
if dynamic is None:
dynamic = mode == "dynamic_seqlen"
kwargs: dict = {"dynamic": dynamic}
from .modeling_esmfold2_common import (
DiffusionModule,
DiffusionTransformer,
PairUpdateBlock,
)
compile_targets = (
PairUpdateBlock,
DiffusionTransformer,
DiffusionModule,
MSAEncoderBlock,
)
def _maybe_compile(module: nn.Module) -> None:
if isinstance(module, compile_targets):
module.forward = torch.compile(module.forward, **kwargs) # type: ignore[assignment]
self.apply(_maybe_compile)
def set_chunk_size(self, chunk_size: int | None) -> None:
self.folding_trunk.set_chunk_size(chunk_size)
if self.lm_encoder is not None:
self.lm_encoder.set_chunk_size(chunk_size)
self.parcae_coda.set_chunk_size(chunk_size)
self.confidence_head.set_chunk_size(chunk_size)
if self.msa_encoder is not None:
self.msa_encoder.set_chunk_size(chunk_size)
def _compute_lm_hidden_states(
    self,
    input_ids: Tensor,
    asym_id: Tensor,
    residue_index: Tensor,
    mol_type: Tensor,
    tok_mask: Tensor,
    lm_mask_pct: float = 0.0,
) -> Tensor:
    """Run the loaded ESMC backbone through ``compute_lm_hidden_states``.

    The call runs inside ``_lm_precision_context``: bf16 autocast, plus TE
    fp8 when the backbone was loaded in fp8. ``SEQUENCE_MASK_TOKEN`` is the
    id used for the ``lm_mask_pct`` masking.
    """
    backbone = self._esmc
    assert backbone is not None
    use_fp8 = self._esmc_fp8
    # fp8 TE kernels require prod(shape[:-1]) % 8 == 0, so pad only in fp8 mode.
    pad_multiple = 8 if use_fp8 else None
    with _lm_precision_context(use_fp8):
        hidden = compute_lm_hidden_states(
            backbone,
            input_ids,
            asym_id,
            residue_index,
            mol_type,
            tok_mask,
            pad_to_multiple=pad_multiple,
            lm_mask_pct=lm_mask_pct,
            mask_token_id=SEQUENCE_MASK_TOKEN,
        )
    return hidden
def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
    """Discretize the learned Parcae dynamics into ``(a, b)``.

    ``delta = softplus(parcae_log_delta)`` is a strictly positive step size.
    ``a = exp(-delta * exp(parcae_log_a))`` lies in (0, 1) and is the decay
    applied to the recurrent pair state. ``b = delta[:, None] * parcae_b_cont``
    scales each row of the continuous input matrix by its step size.
    """
    step = F.softplus(self.parcae_log_delta)
    rate = torch.exp(self.parcae_log_a)
    decay = torch.exp(-step * rate)
    input_matrix = step.unsqueeze(1) * self.parcae_b_cont
    return decay, input_matrix
def _init_pair_state(self, ref: Tensor) -> Tensor:
    """Return a fresh random state with the shape/device/layout of ``ref``.

    Values are drawn from a normal distribution with
    ``std = sqrt(2 / (5 * C))`` (``C = ref.shape[-1]``), truncated to
    ``[-3 std, 3 std]``. Sampling is done in float32 and the result is then
    cast to ``ref.dtype``.
    """
    channels = ref.shape[-1]
    sigma = math.sqrt(2.0 / (5.0 * channels))
    bound = 3 * sigma
    buffer = torch.empty_like(ref, dtype=torch.float32)
    nn.init.trunc_normal_(buffer, mean=0.0, std=sigma, a=-bound, b=bound)
    return buffer.to(dtype=ref.dtype)
def _run_one_loop(
self,
z: Tensor,
z_init: Tensor,
lm_z: Tensor | None,
_msa_inputs: dict | None,
pair_mask: Tensor,
a: Tensor,
b_mat: Tensor,
tok_mask: Tensor,
total_steps: int,
) -> Tensor:
"""Run ``total_steps`` iterations of the recurrent Parcae pair update.

Each iteration:
  1. optionally applies dropout to ``lm_z`` (config flags
     ``per_loop_lm_dropout`` / ``lm_dropout``), then refines it with
     ``self.lm_encoder`` when one exists; without an LM encoder the raw
     ``lm_z`` is added to the injection pair instead;
  2. starts the injection pair from ``z_init`` and, when an MSA encoder and
     ``_msa_inputs`` are both present, runs the MSA encoder on a (possibly
     subsampled) MSA. Its output replaces the injection when
     ``config.msa_encoder_overwrite`` is set, otherwise it is added to it;
  3. adds the refined LM pair (if any) and normalizes with
     ``parcae_input_norm``;
  4. updates ``z = a * z + injected @ b_mat.T`` (via ``F.linear``) and runs
     ``self.folding_trunk`` on the result.

Args:
    z: recurrent pair state being updated.
    z_init: base pair features re-injected at every iteration.
    lm_z: optional LM-derived pair features.
    _msa_inputs: optional dict built by ``forward`` with keys ``x_inputs``,
        ``msa``, ``msa_attention_mask``, ``has_deletion``,
        ``deletion_value``, ``max_depth``, ``subsample_enabled``.
    pair_mask: pair attention mask for the trunk and LM encoder.
    a, b_mat: discretized dynamics (decay and input matrix).
    tok_mask: token mask; used to build an MSA mask when none is provided.
    total_steps: number of iterations to run.

Returns:
    The updated pair state ``z``.
"""
# Helper method (not inline) so per-iter locals free on return —
# otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
# training=True forces dropout under eval(), matching the per-loop
# dropout strategy used at train time.
lm_cfg = self.config.lm_encoder
_per_loop_lm_dropout = (
lm_z is not None
and getattr(lm_cfg, "per_loop_lm_dropout", False)
and getattr(lm_cfg, "lm_dropout", 0.0) > 0.0
)
_lm_dropout_p = getattr(lm_cfg, "lm_dropout", 0.0)
for _ in range(total_steps):
# A fresh dropout mask on the LM pair features every iteration.
if _per_loop_lm_dropout:
assert lm_z is not None # narrowed by _per_loop_lm_dropout
lm_z_i: Tensor | None = F.dropout(lm_z, p=_lm_dropout_p, training=True)
else:
lm_z_i = lm_z
refined_lm_z: Tensor | None = None
if lm_z_i is not None and self.lm_encoder is not None:
refined_lm_z = self.lm_encoder(
lm_z_i.to(z_init.dtype), pair_attention_mask=pair_mask
)
z_inject_pair = z_init
# Without an LM encoder the raw LM pair features are injected directly.
if lm_z_i is not None and self.lm_encoder is None:
z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
if self.msa_encoder is not None and _msa_inputs is not None:
# Subsampling (when enabled) is redone every iteration, so each loop
# may see a different MSA subset.
msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa(
_msa_inputs["msa"],
_msa_inputs["msa_attention_mask"],
_msa_inputs["has_deletion"],
_msa_inputs["deletion_value"],
max_depth=_msa_inputs["max_depth"],
enabled=_msa_inputs["subsample_enabled"],
)
# MSA tensors are (B, M, L); everything below is permuted to
# (B, L, M, ...) as the MSA encoder expects.
B_msa, M, L_msa = msa_i.shape
msa_oh = F.one_hot(
msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
).float()
# No MSA mask given: treat every MSA row as valid wherever the token is valid.
msa_attn = (
mask_i.permute(0, 2, 1).float()
if mask_i is not None
else tok_mask[:, :, None].expand(-1, -1, M).float()
)
# Bias-free MSAEncoder.embed requires zeroed padding.
msa_oh = msa_oh * msa_attn.unsqueeze(-1)
# Missing deletion features default to zeros of shape (B, L, M).
hd = (
hd_i.permute(0, 2, 1).float()
if hd_i is not None
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
)
dv = (
dv_i.permute(0, 2, 1).float()
if dv_i is not None
else torch.zeros(B_msa, L_msa, M, device=msa_i.device)
)
msa_pair = self.msa_encoder(
x_pair=z_inject_pair,
x_inputs=_msa_inputs["x_inputs"],
msa_oh=msa_oh,
has_deletion=hd,
deletion_value=dv,
msa_attention_mask=msa_attn,
).to(z_inject_pair.dtype)
# The MSA encoder output either replaces or is added to the injection.
z_inject_pair = (
msa_pair
if self.config.msa_encoder_overwrite
else (z_inject_pair + msa_pair)
)
if refined_lm_z is not None:
z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
injected_pair = self.parcae_input_norm(z_inject_pair)
# Linear recurrence: F.linear(x, W) == x @ W.T mixes the channels of the
# injected pair; ``a`` decays the previous state.
z = a * z + F.linear(injected_pair.to(z.dtype), b_mat)
z = self.folding_trunk(z, pair_attention_mask=pair_mask)
return z
@torch.inference_mode()
def forward(
self,
token_index: Tensor,
residue_index: Tensor,
asym_id: Tensor,
sym_id: Tensor,
entity_id: Tensor,
mol_type: Tensor,
res_type: Tensor,
token_bonds: Tensor,
token_attention_mask: Tensor,
ref_pos: Tensor,
ref_element: Tensor,
ref_charge: Tensor,
ref_atom_name_chars: Tensor,
ref_space_uid: Tensor,
atom_attention_mask: Tensor,
atom_to_token: Tensor,
distogram_atom_idx: Tensor,
deletion_mean: Tensor | None = None,
msa: Tensor | None = None,
has_deletion: Tensor | None = None,
deletion_value: Tensor | None = None,
msa_attention_mask: Tensor | None = None,
input_ids: Tensor | None = None,
lm_hidden_states: Tensor | None = None,
num_loops: int | None = None,
num_diffusion_samples: int | None = None,
num_sampling_steps: int | None = None,
lm_mask_pct: float | None = None,
msa_max_depth: int = 1024,
msa_column_mask_rate: float = 0.1,
msa_subsample_at_inference: bool = True,
**kwargs,
) -> dict[str, Tensor]:
"""Run the full ESMFold2 inference pass on pre-featurized inputs.

Steps, in order:
  1. Embed token and atom inputs.
  2. Build the initial pair representation ``z_init``.
  3. Optionally compute ESM-C hidden states.
  4. Iterate the recurrent Parcae trunk ``n_loops + 1`` times (at least
     once).
  5. Apply the readout and coda.
  6. Predict distogram logits.
  7. Sample atom coordinates with the diffusion structure head.
  8. Score the sampled coordinates with the confidence head.

Runs under ``torch.inference_mode()``.

Args:
    res_type: integer residue-type ids (2-D; one-hot encoded here and
        masked by ``token_attention_mask``) or an already one-hot/float
        tensor (used as-is).
    msa / msa_attention_mask: when ``msa`` is given, the residue profile
        is the mask-weighted mean of its one-hot encoding over dim 1.
        That is presumably the MSA-depth axis, going by the
        ``B, M, L`` unpacking in ``_run_one_loop``. Otherwise the profile
        is the one-hot ``res_type``. The MSA also feeds the MSA encoder
        when one is configured.
    deletion_mean: defaults to zeros shaped like ``res_type.shape[:2]``.
    input_ids / lm_hidden_states: ESM-C hidden states are computed from
        ``input_ids`` only when ``lm_hidden_states`` is not supplied and
        ESM-C is loaded.
    num_loops / num_diffusion_samples / lm_mask_pct: override config
        defaults when not ``None``.
    msa_max_depth / msa_subsample_at_inference: control MSA subsampling
        inside each recurrent loop.
    msa_column_mask_rate: passed to ``maybe_apply_msa_column_masking``.
    **kwargs: accepted and ignored.

Returns:
    Dict with ``distogram_logits``, ``sample_atom_coords``, every entry
    from the confidence head, ``atom_pad_mask``, ``residue_index`` and
    ``entity_id``.
"""
tok_mask = token_attention_mask
atm_mask = atom_attention_mask
disto_idx = distogram_atom_idx
n_loops: int = num_loops if num_loops is not None else self.config.num_loops
n_samples: int = (
num_diffusion_samples
if num_diffusion_samples is not None
else self.config.num_diffusion_samples
)
# Number of recurrent trunk iterations: n_loops + 1, never fewer than 1.
total_steps = max(1, n_loops + 1)
# Integer ids -> masked one-hot; any other rank is assumed to be one-hot/float already.
if res_type.dim() == 2:
res_type_oh = F.one_hot(res_type.long(), num_classes=NUM_RES_TYPES).float()
res_type_oh = res_type_oh * tok_mask.unsqueeze(-1).float()
else:
res_type_oh = res_type.float()
# Residue profile: mean one-hot over MSA rows, counting only unmasked rows
# (the count is clamped to 1 so fully-masked positions do not divide by 0).
if msa is not None:
msa_oh_profile = F.one_hot(msa.long(), num_classes=NUM_RES_TYPES).float()
if msa_attention_mask is not None:
mask_f = msa_attention_mask.float().unsqueeze(-1)
msa_oh_profile = msa_oh_profile * mask_f
valid_seq_count = msa_attention_mask.float().sum(dim=1).clamp(min=1)
profile = msa_oh_profile.sum(dim=1) / valid_seq_count.unsqueeze(-1)
else:
profile = msa_oh_profile.mean(dim=1)
else:
profile = res_type_oh
if deletion_mean is None:
deletion_mean = torch.zeros(
res_type.shape[0], res_type.shape[1], device=res_type.device
)
ref_element_oh = F.one_hot(
ref_element.long(), num_classes=MAX_ATOMIC_NUMBER
).float()
ref_atom_name_chars_oh = F.one_hot(
ref_atom_name_chars.long(), num_classes=CHAR_VOCAB_SIZE
).float()
# Bias-free downstream Linears require zeroed padding.
atm_mask_f = atm_mask.float()
ref_element_oh = ref_element_oh * atm_mask_f.unsqueeze(-1)
ref_atom_name_chars_oh = ref_atom_name_chars_oh * atm_mask_f.unsqueeze(
-1
).unsqueeze(-1)
# Padded atoms map to token 0.
atom_to_token = atom_to_token * atm_mask.long()
# bf16 autocast is enabled only when the inputs live on CUDA.
use_amp = ref_pos.device.type == "cuda"
with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
x_inputs = self.inputs_embedder(
aatype=res_type_oh,
profile=profile.float(),
deletion_mean=deletion_mean.float(),
ref_pos=ref_pos,
atom_attention_mask=atm_mask,
ref_space_uid=ref_space_uid,
ref_charge=ref_charge,
ref_element=ref_element_oh,
ref_atom_name_chars=ref_atom_name_chars_oh,
atom_to_token=atom_to_token,
)
# Pair init: outer sum of two projections of the token embeddings
# (row term broadcast over dim 2, column term over dim 1).
z_init = self.z_init_1(x_inputs).unsqueeze(2) + self.z_init_2(
x_inputs
).unsqueeze(1)
relative_position_encoding = self.rel_pos(
residue_index=residue_index,
asym_id=asym_id,
sym_id=sym_id,
entity_id=entity_id,
token_index=token_index,
)
token_bonds_encoding = self.token_bonds(token_bonds.float())
z_init = z_init + relative_position_encoding + token_bonds_encoding
# ESM-C runs on the fly only when no precomputed hidden states were passed.
if (
lm_hidden_states is None
and input_ids is not None
and self._esmc is not None
):
lm_hidden_states = self._compute_lm_hidden_states(
input_ids,
asym_id,
residue_index,
mol_type,
tok_mask,
lm_mask_pct=(
self.config.lm_mask_pct
if lm_mask_pct is None
else lm_mask_pct
),
)
# Project (detached) LM hidden states into pair features.
lm_z: Tensor | None = None
if lm_hidden_states is not None:
lm_z = self.language_model(lm_hidden_states.detach())
del lm_hidden_states
pair_mask = tok_mask[:, :, None].float() * tok_mask[:, None, :].float()
# Random initial recurrent state; dynamics cast to its device/dtype and
# ``a`` reshaped to broadcast over the last (channel) dimension.
z = self._init_pair_state(z_init)
a, b = self._discretized_dynamics()
a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
b_mat = b.to(device=z.device, dtype=z.dtype)
_msa_inputs: dict | None = None
if self.msa_encoder is not None and msa is not None:
# NOTE(review): column masking is applied here even under inference with a
# default rate of 0.1 — confirm this is intended.
msa_attention_mask = maybe_apply_msa_column_masking(
msa_attention_mask,
msa_column_mask_rate,
)
_msa_inputs = dict(
x_inputs=x_inputs,
msa=msa,
msa_attention_mask=msa_attention_mask,
has_deletion=has_deletion,
deletion_value=deletion_value,
max_depth=msa_max_depth,
subsample_enabled=msa_subsample_at_inference,
)
# Method call (not inline loop) frees per-iter L²×c_z locals.
z = self._run_one_loop(
z=z,
z_init=z_init,
lm_z=lm_z,
_msa_inputs=_msa_inputs,
pair_mask=pair_mask,
a=a,
b_mat=b_mat,
tok_mask=tok_mask,
total_steps=total_steps,
)
# Drop loop-only tensors before the readout / sampling stages.
del z_init, lm_z, _msa_inputs, a, b_mat
z = self.parcae_readout(z)
z = self.parcae_coda(z, pair_attention_mask=pair_mask)
z = z.float()
# Distogram head sees the symmetrized pair (z + z with token axes swapped).
distogram_logits = self.distogram_head(z + z.transpose(-2, -3))
structure_output = self.structure_head.sample(
z_trunk=z,
s_inputs=x_inputs,
s_trunk=None,
relative_position_encoding=relative_position_encoding,
ref_pos=ref_pos,
ref_charge=ref_charge,
ref_mask=atm_mask,
ref_element=ref_element_oh,
ref_atom_name_chars=ref_atom_name_chars_oh,
ref_space_uid=ref_space_uid,
tok_idx=atom_to_token,
asym_id=asym_id,
residue_index=residue_index,
entity_id=entity_id,
token_index=token_index,
sym_id=sym_id,
token_attention_mask=tok_mask,
num_diffusion_samples=n_samples,
num_sampling_steps=num_sampling_steps,
return_atom_repr=False,
denoising_early_exit_rmsd=None,
)
sample_coords = structure_output["sample_atom_coords"]
assert sample_coords is not None
output: dict[str, Tensor] = {"distogram_logits": distogram_logits}
output["sample_atom_coords"] = sample_coords
# Confidence head scores the sampled coordinates from detached inputs.
confidence_output = self.confidence_head(
s_inputs=x_inputs.detach(),
z=z.detach().float(),
x_pred=sample_coords.detach(),
distogram_atom_idx=disto_idx,
token_attention_mask=tok_mask,
atom_to_token=atom_to_token,
atom_attention_mask=atm_mask,
asym_id=asym_id,
mol_type=mol_type,
num_diffusion_samples=n_samples,
relative_position_encoding=relative_position_encoding.detach(),
token_bonds_encoding=token_bonds_encoding.detach(),
)
output.update(confidence_output)
# A 1-D atom mask gets a leading batch dimension.
output["atom_pad_mask"] = (
atm_mask.unsqueeze(0) if atm_mask.dim() == 1 else atm_mask
)
output["residue_index"] = residue_index
output["entity_id"] = entity_id
return output
@torch.no_grad()
def infer_protein(self, seq: str, **forward_kwargs) -> dict:
    """Fold one protein sequence via a raw ``forward`` call (no gradients).

    Features come from ``prepare_protein_features(seq)``. Every tensor is
    moved to ``self.device``, and the raw output dict of ``self(...)`` is
    returned. Extra keyword arguments are passed to ``forward``.
    """
    from .protein_utils import prepare_protein_features

    target_device = self.device
    moved = {}
    for name, tensor in prepare_protein_features(seq).items():
        moved[name] = tensor.to(target_device)
    return self(**moved, **forward_kwargs)
@property
def input_builder(self):
    """Lazily constructed ``ESMFold2InputBuilder``, cached on the instance."""
    builder = self._esmfold2_input_builder
    if builder is not None:
        return builder
    from .esmfold2_processor import ESMFold2InputBuilder

    builder = ESMFold2InputBuilder()
    self._esmfold2_input_builder = builder
    return builder
@property
def input_types(self):
    """The ``esmfold2_types`` module (e.g. ``ProteinInput``, ``MSA``) for building fold inputs."""
    from . import esmfold2_types as types_module

    return types_module
def prepare_structure_input(self, input, seed: int | None = None):
return self.input_builder.prepare_input(input, seed=seed, device=self.device)
def fold(
self,
input,
*,
num_loops: int = 3,
num_sampling_steps: int = 50,
num_diffusion_samples: int = 1,
seed: int | None = None,
noise_scale: float | None = None,
step_scale: float | None = None,
max_inference_sigma: int | None = None,
early_exit: bool = False,
complex_id: str = "pred",
):
return self.input_builder.fold(
self,
input,
num_loops=num_loops,
num_sampling_steps=num_sampling_steps,
num_diffusion_samples=num_diffusion_samples,
seed=seed,
noise_scale=noise_scale,
step_scale=step_scale,
max_inference_sigma=max_inference_sigma,
early_exit=early_exit,
complex_id=complex_id,
)
def _fold_protein_no_ttt(
    self,
    sequence: str,
    *,
    chain_id: str = "A",
    msa: Any | None = None,
    msa_path: str | Path | None = None,
    msa_max_sequences: int | None = None,
    num_loops: int = 3,
    num_sampling_steps: int = 50,
    num_diffusion_samples: int = 1,
    seed: int | None = None,
    complex_id: str = "pred",
):
    """Fold a single protein chain, optionally with an MSA, without TTT.

    At most one of ``msa`` / ``msa_path`` may be given. An ``msa_path`` is
    read with ``MSA.from_a3m`` (capped at ``msa_max_sequences``). If an MSA
    is used, its gap-stripped, upper-cased query must equal
    ``sequence.upper()``. The chain is wrapped in a
    ``StructurePredictionInput`` and folded with ``self.fold``.
    """
    from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput

    assert msa is None or msa_path is None, "Pass at most one of msa or msa_path."
    if msa_path is not None:
        msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences)
    if msa is not None:
        expected = sequence.upper()
        query = str(msa.query).replace("-", "").upper()
        assert query == expected, (
            f"MSA query does not match sequence: expected {expected!r}, got {query!r}"
        )
    chain = ProteinInput(id=chain_id, sequence=sequence, msa=msa)
    request = StructurePredictionInput(sequences=[chain])
    return self.fold(
        request,
        num_loops=num_loops,
        num_sampling_steps=num_sampling_steps,
        num_diffusion_samples=num_diffusion_samples,
        seed=seed,
        complex_id=complex_id,
    )
@staticmethod
def _ttt_mean_plddt(result) -> float:
assert result.plddt is not None, "ESMFold2 result has no pLDDT tensor."
return float(result.plddt.float().mean().item())
def _ttt_select_result(self, result):
    """Reduce a fold result to one prediction.

    A single result is returned unchanged. For a non-empty list, the entry
    with the highest ``self._ttt_mean_plddt`` is returned.
    """
    if not isinstance(result, list):
        return result
    assert len(result) > 0, "ESMFold2 fold returned an empty result list."
    return max(result, key=self._ttt_mean_plddt)
def _ttt_eval_step(
self,
step: int,
loss: float,
seq: str | list[str] | None = None,
input_ids: torch.Tensor | None = None,
**kwargs,
) -> tuple[dict[str, Any], float | None]:
del input_ids
assert isinstance(seq, str), (
"ESMFold2 fold TTT is protein-only and sequence-string only."
)
fold_kwargs = kwargs["fold_kwargs"]
was_training = self.training
self.eval()
try:
result = self._fold_protein_no_ttt(seq, **fold_kwargs)
finally:
self.train(was_training)
selected = self._ttt_select_result(result)
plddt = self._ttt_mean_plddt(selected)
return {
"step": step,
"loss": loss,
"plddt": plddt,
"result": selected,
}, plddt
def fold_protein(
    self,
    sequence: str,
    *,
    chain_id: str = "A",
    msa: Any | None = None,
    msa_path: str | Path | None = None,
    msa_max_sequences: int | None = None,
    num_loops: int = 3,
    num_sampling_steps: int = 50,
    num_diffusion_samples: int = 1,
    seed: int | None = None,
    complex_id: str = "pred",
    ttt: bool = False,
    ttt_config: TTTConfig | dict[str, Any] | None = None,
):
    """Fold one protein chain, optionally with test-time training.

    With ``ttt=False`` this calls ``_fold_protein_no_ttt``. With
    ``ttt=True`` it calls ``fold_protein_ttt`` with the same options plus
    ``ttt_config``.
    """
    common = dict(
        sequence=sequence,
        chain_id=chain_id,
        msa=msa,
        msa_path=msa_path,
        msa_max_sequences=msa_max_sequences,
        num_loops=num_loops,
        num_sampling_steps=num_sampling_steps,
        num_diffusion_samples=num_diffusion_samples,
        seed=seed,
        complex_id=complex_id,
    )
    if not ttt:
        return self._fold_protein_no_ttt(**common)
    return self.fold_protein_ttt(**common, ttt_config=ttt_config)
def fold_protein_ttt(
    self,
    sequence: str,
    *,
    chain_id: str = "A",
    msa: Any | None = None,
    msa_path: str | Path | None = None,
    msa_max_sequences: int | None = None,
    num_loops: int = 3,
    num_sampling_steps: int = 50,
    num_diffusion_samples: int = 1,
    seed: int | None = None,
    complex_id: str = "pred",
    ttt_config: TTTConfig | dict[str, Any] | None = None,
):
    """Fold with test-time training and return the best-pLDDT prediction.

    Requires a loaded ESM-C model and raises ``RuntimeError`` when it runs
    in fp8. The sequence is first folded without TTT to get a baseline.
    ``self.ttt`` is then run with per-step evaluation enabled and its
    automatic best-state reset disabled. The returned result is the one with
    the strictly highest mean pLDDT among the baseline and every evaluated
    step. It carries a ``ttt_metrics`` dict (``losses``, ``step_plddts``,
    ``baseline_plddt``, ``best_plddt``, ``best_step``). TTT weights are
    reset afterwards whenever TTT was initialized, even on error.
    """
    assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True."
    if self._esmc_fp8:
        raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.")
    fold_kwargs = {
        "chain_id": chain_id,
        "msa": msa,
        "msa_path": msa_path,
        "msa_max_sequences": msa_max_sequences,
        "num_loops": num_loops,
        "num_sampling_steps": num_sampling_steps,
        "num_diffusion_samples": num_diffusion_samples,
        "seed": seed,
        "complex_id": complex_id,
    }
    baseline = self._ttt_select_result(
        self._fold_protein_no_ttt(sequence, **fold_kwargs)
    )
    baseline_plddt = self._ttt_mean_plddt(baseline)
    plddt_history = [baseline_plddt]
    # (pLDDT, step, result) of the best prediction seen so far; step 0 is the baseline.
    best = (baseline_plddt, 0, baseline)
    cfg = self.ttt_config.merged(ttt_config).merged(
        {"eval_each_step": True, "automatic_best_state_reset": False}
    )
    try:
        metrics = self.ttt(
            seq=sequence,
            ttt_config=cfg,
            fold_kwargs=fold_kwargs,
        )
        for entry in metrics["step_metrics"]:
            candidate_plddt = entry["plddt"]
            plddt_history.append(candidate_plddt)
            if candidate_plddt > best[0]:
                best = (candidate_plddt, entry["step"], entry["result"])
        best_plddt, best_step, best_result = best
        best_result.ttt_metrics = {
            "losses": metrics["losses"],
            "step_plddts": plddt_history,
            "baseline_plddt": baseline_plddt,
            "best_plddt": best_plddt,
            "best_step": best_step,
        }
        return best_result
    finally:
        if "_ttt_initialized" in self.__dict__ and self._ttt_initialized:
            self.ttt_reset()
@staticmethod
def result_to_cif(result) -> str:
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
return result.complex.to_mmcif()
@staticmethod
def result_to_pdb(result) -> str:
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
return result.complex.to_protein_complex().to_pdb_string()
def save_as_cif(self, result, output_path: str | Path) -> None:
Path(output_path).write_text(self.result_to_cif(result))
def save_as_pdb(self, result, output_path: str | Path) -> None:
Path(output_path).write_text(self.result_to_pdb(result))
def infer_protein_as_cif(self, seq: str, **forward_kwargs) -> str:
    """Fold ``seq`` with ``fold_protein`` and return the prediction as mmCIF text.

    Extra keyword arguments are passed to ``fold_protein``.
    """
    prediction = self.fold_protein(seq, **forward_kwargs)
    return self.result_to_cif(prediction)
def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str:
    """Fold ``seq`` with ``fold_protein`` and return the prediction as PDB text.

    Extra keyword arguments are passed to ``fold_protein``.
    """
    prediction = self.fold_protein(seq, **forward_kwargs)
    return self.result_to_pdb(prediction)
class MSAEncoderBlock(nn.Module):
    """One MSA encoder block: OPM into pair, MSA pair-weighted averaging, triangle update.

    Update order in ``forward``:
      1. Outer-product mean of the MSA is added to the pair.
      2. Non-final blocks only: MSA pair-weighted averaging and the MSA
         transition update ``m``.
      3. Outgoing and incoming triangle multiplicative updates and the pair
         transition update the pair.

    The final block (``is_final_block=True``) builds no MSA-update modules.
    Its MSA output is not used by ``MSAEncoder.forward``, which returns only
    the pair.
    """

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_hidden: int,
        n_heads_msa: int,
        msa_head_width: int,
        is_final_block: bool = False,
    ) -> None:
        super().__init__()
        self.is_final_block = is_final_block
        # Submodule attribute names (and registration order) define the
        # state_dict layout — keep them unchanged.
        self.outer_product_mean = OuterProductMean(d_msa, d_hidden, d_pair)
        if not is_final_block:
            self.msa_pair_weighted_averaging = MSAPairWeightedAveraging(
                d_msa, d_pair, n_heads_msa, msa_head_width
            )
            self.msa_transition = PairTransition(d_msa, expansion_ratio=4)
        self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
        self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
        self.pair_transition = PairTransition(d_pair, expansion_ratio=4)

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Propagate ``chunk_size`` to every chunk-aware submodule of this block."""
        chunked = [self.outer_product_mean, self.tri_mul_out, self.tri_mul_in]
        if not self.is_final_block:
            chunked.append(self.msa_transition)
        chunked.append(self.pair_transition)
        for submodule in chunked:
            submodule.set_chunk_size(chunk_size)

    def forward(
        self,
        m: Tensor,
        pair: Tensor,
        msa_attention_mask: Tensor,
        pair_attention_mask: Tensor,
    ) -> tuple[Tensor, Tensor]:
        """Return the updated ``(m, pair)``; all updates are residual."""
        pair = pair + self.outer_product_mean(m, msa_attention_mask)
        if not self.is_final_block:
            m = m + self.msa_pair_weighted_averaging(m, pair, pair_attention_mask)
            m = m + self.msa_transition(m)
        for triangle_update in (self.tri_mul_out, self.tri_mul_in):
            pair = pair + triangle_update(pair, mask=pair_attention_mask)
        pair = pair + self.pair_transition(pair)
        return m, pair
class MSAEncoder(nn.Module):
    """Stack of [`MSAEncoderBlock`] layers that conditions the pair on an MSA.

    The MSA representation is the embedding of the per-entry features plus
    a projection of the token inputs broadcast over the MSA axis. It is then
    refined together with the pair by ``n_layers`` blocks; only the last one
    is marked final. Only the updated pair is returned.
    """

    def __init__(
        self,
        d_msa: int,
        d_pair: int,
        d_inputs: int,
        d_hidden: int = 32,
        n_layers: int = 4,
        n_heads_msa: int = 8,
        msa_head_width: int = 16,
    ) -> None:
        super().__init__()
        # 35 input features per MSA entry: the one-hot residue channels plus
        # the has_deletion and deletion_value scalars concatenated in forward
        # (so msa_oh must carry 33 channels).
        self.embed = nn.Linear(35, d_msa, bias=False)
        self.project_inputs = nn.Linear(d_inputs, d_msa, bias=False)
        last_index = n_layers - 1
        layers = []
        for index in range(n_layers):
            layers.append(
                MSAEncoderBlock(
                    d_msa=d_msa,
                    d_pair=d_pair,
                    d_hidden=d_hidden,
                    n_heads_msa=n_heads_msa,
                    msa_head_width=msa_head_width,
                    is_final_block=index == last_index,
                )
            )
        self.blocks = nn.ModuleList(layers)

    def set_chunk_size(self, chunk_size: int | None) -> None:
        """Propagate ``chunk_size`` to every block."""
        for layer in self.blocks:
            cast(MSAEncoderBlock, layer).set_chunk_size(chunk_size)

    def forward(
        self,
        x_pair: Tensor,
        x_inputs: Tensor,
        msa_oh: Tensor,
        has_deletion: Tensor,
        deletion_value: Tensor,
        msa_attention_mask: Tensor,
    ) -> Tensor:
        """Return ``x_pair`` refined by all blocks.

        All MSA inputs arrive already transposed to ``[B, L, M, ...]`` by the
        caller.
        """
        entry_features = torch.cat(
            (msa_oh, has_deletion[..., None], deletion_value[..., None]), dim=-1
        )
        m = self.embed(entry_features) + self.project_inputs(x_inputs)[:, :, None]
        # Token validity is taken from MSA row 0 (presumably the query).
        token_mask = msa_attention_mask[:, :, 0].bool()
        pair_mask = token_mask[:, :, None] & token_mask[:, None, :]
        for block in self.blocks:
            m, x_pair = block(m, x_pair, msa_attention_mask, pair_mask)
        return x_pair