"""PyTorch ESMFold2 model — the standard released architecture. Quickstart:: from transformers import ESMFold2Model model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval() open("ubq.pdb", "w").write(model.infer_protein_as_pdb("MQIFVKTLTGKT...")) For multi-chain, ligand, and MSA inputs, use ``model.input_types`` together with ``model.fold(...)`` or ``model.prepare_structure_input(...)``. """ import importlib import math from contextlib import contextmanager from pathlib import Path from typing import Any, cast import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor try: te = importlib.import_module("transformer_engine.pytorch") te_recipe = importlib.import_module("transformer_engine.common.recipe") DelayedScaling = te_recipe.DelayedScaling Format = te_recipe.Format TE_AVAILABLE = True except ImportError: te = None # type: ignore[assignment] DelayedScaling = None # type: ignore[assignment] Format = None # type: ignore[assignment] TE_AVAILABLE = False from transformers.modeling_utils import PreTrainedModel try: from fastplms.test_time_training import FastPLMTestTimeTrainingMixin, TTTConfig except ImportError: from .test_time_training import FastPLMTestTimeTrainingMixin, TTTConfig from .configuration_esmfold2 import ESMFold2Config, normalize_esmc_id from .modeling_esmfold2_common import ( CHAR_VOCAB_SIZE, MAX_ATOMIC_NUMBER, NUM_RES_TYPES, DiffusionStructureHead, FoldingTrunk, InputsEmbedder, LanguageModelShim, MSAPairWeightedAveraging, OuterProductMean, ResIdxAsymIdSymIdEntityIdEncoding, RowAttentionPooling, SwiGLUMLP, TriangleMultiplicativeUpdate, _categorical_mean, _compute_intra_token_idx, compute_lm_hidden_states, gather_rep_atom_coords, gather_token_to_atom, maybe_apply_msa_column_masking, maybe_subsample_msa, ) from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer from .esmfold2_conformers import load_ccd as _fastplms_esmfold2_load_ccd from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL as _FASTPLMS_ESMFOLD2_ELEMENT_NUMBER_TO_SYMBOL from .esmfold2_constants_esm3 import ( CHAIN_BREAK_STR as _FASTPLMS_ESMFOLD2_CHAIN_BREAK_STR, SEQUENCE_BOS_TOKEN, SEQUENCE_EOS_TOKEN, SEQUENCE_MASK_TOKEN, SEQUENCE_PAD_TOKEN, SEQUENCE_STANDARD_AA_MAX_TOKEN, SEQUENCE_STANDARD_AA_MIN_TOKEN, SEQUENCE_VOCAB, ) from .esmfold2_input_builder import StructurePredictionInput as _FastPLMSESMFold2StructurePredictionInput from .esmfold2_metrics import compute_rmsd as _fastplms_esmfold2_compute_rmsd from .esmfold2_misc import slice_any_object as _fastplms_esmfold2_slice_any_object from .esmfold2_mmcif_parsing import MmcifWrapper as _FastPLMSESMFold2MmcifWrapper from .esmfold2_molecular_complex import MolecularComplex as _FastPLMSESMFold2MolecularComplex from .esmfold2_msa import MSA as _FastPLMSESMFold2MSA from .esmfold2_msa_filter_sequences import greedy_select_indices as _fastplms_esmfold2_greedy_select_indices from .esmfold2_normalize_coordinates import normalize_coordinates as _fastplms_esmfold2_normalize_coordinates from .esmfold2_output import build_molecular_complex_from_features as _fastplms_esmfold2_build_molecular_complex_from_features from .esmfold2_paired_msa import construct_paired_msa as _fastplms_esmfold2_construct_paired_msa from .esmfold2_parsing import FastaEntry as _FastPLMSESMFold2FastaEntry from .esmfold2_predicted_aligned_error import compute_tm as _fastplms_esmfold2_compute_tm from .esmfold2_prepare_input import prepare_esmfold2_input as _fastplms_esmfold2_prepare_esmfold2_input from .esmfold2_processor import ESMFold2InputBuilder as _FastPLMSESMFold2InputBuilder from .esmfold2_protein_chain import ProteinChain as _FastPLMSESMFold2ProteinChain from .esmfold2_protein_complex import ProteinComplex as _FastPLMSESMFold2ProteinComplex from .esmfold2_protein_structure import index_by_atom_name as _fastplms_esmfold2_index_by_atom_name from .esmfold2_residue_constants import restypes as _FASTPLMS_ESMFOLD2_RESTYPES from .esmfold2_sequential_dataclass import SequentialDataclass as _FastPLMSESMFold2SequentialDataclass from .esmfold2_system import run_subprocess_with_errorcheck as _fastplms_esmfold2_run_subprocess_with_errorcheck from .esmfold2_types import ProteinInput as _FastPLMSESMFold2ProteinInput from .esmfold2_utils_types import PathOrBuffer as _FastPLMSESMFold2PathOrBuffer _EPS = 1e-6 _NONPOLYMER_ID = 4 # Default for the triangle / OPM / pair-transition L² ops. Caps peak memory # so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438; # chunk=64 leaves headroom for the largest foldbench targets). Override via # ``model.set_chunk_size(...)``; pass None to disable chunking (faster for # short L but OOM-prone past ~600). _DEFAULT_CHUNK_SIZE = 64 class _ESMFold2ESMplusplusAdapter(nn.Module): def __init__(self, model: nn.Module) -> None: super().__init__() self.model = model @property def config(self): return self.model.config def forward( self, input_ids: Tensor, attention_mask: Tensor | None = None, sequence_id: Tensor | None = None, output_hidden_states: bool | None = None, output_attentions: bool | None = None, return_dict: bool | None = None, compute_sae: bool = True, normalize_sae: bool = False, ): del return_dict, compute_sae, normalize_sae output = self.model( input_ids=input_ids, attention_mask=attention_mask, sequence_id=sequence_id, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=True, esmfold2_hidden_states=True, ) if output_hidden_states: hidden_states = output.hidden_states assert hidden_states is not None, "ESM++ did not return hidden states." if isinstance(hidden_states, torch.Tensor): output.hidden_states = hidden_states else: output.hidden_states = torch.stack(tuple(hidden_states), dim=0) return output def _load_fastplms_esmplusplus_for_esmfold2( esmc_model_path: str, attn_backend: str, device: torch.device, dtype: torch.dtype, ) -> _ESMFold2ESMplusplusAdapter: try: from fastplms.esm_plusplus.modeling_esm_plusplus import ( ESMplusplusConfig, ESMplusplusModel, ) except ImportError: from .modeling_esm_plusplus import ESMplusplusConfig, ESMplusplusModel normalized_path = normalize_esmc_id(esmc_model_path) esmc_config = ESMplusplusConfig.from_pretrained(normalized_path) esmc_config.attn_backend = attn_backend esmc = ESMplusplusModel.from_pretrained( normalized_path, config=esmc_config, ) return _ESMFold2ESMplusplusAdapter(esmc).to(device=device, dtype=dtype).eval() class PairTransition(nn.Module): """LayerNorm + SwiGLU feed-forward residual block on the pair representation.""" def __init__(self, d_model: int, expansion_ratio: int = 4) -> None: super().__init__() self.norm = nn.LayerNorm(d_model) self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False) self._chunk_size: int | None = _DEFAULT_CHUNK_SIZE def set_chunk_size(self, chunk_size: int | None) -> None: self._chunk_size = chunk_size def forward(self, x: Tensor) -> Tensor: if self._chunk_size is None or x.shape[1] <= self._chunk_size: return self.ffn(self.norm(x)) out: list[Tensor] = [] for s in range(0, x.shape[1], self._chunk_size): e = min(s + self._chunk_size, x.shape[1]) sl = x[:, s:e] out.append(self.ffn(self.norm(sl))) return torch.cat(out, dim=1) class ConfidenceHead(nn.Module): """Predicts pLDDT, PAE, PDE, resolved-atom probability and distogram bins.""" boundaries: Tensor def __init__(self, config: "ESMFold2Config") -> None: super().__init__() ch = config.confidence_head d_single = config.d_single d_pair = config.d_pair d_inputs = config.inputs.d_inputs boundaries = torch.linspace(ch.min_dist, ch.max_dist, ch.distogram_bins - 1) self.register_buffer("boundaries", boundaries) self.dist_bin_pairwise_embed = nn.Embedding(ch.distogram_bins, d_pair) self.s_norm = nn.LayerNorm(d_single) self.s_inputs_to_single = nn.Linear(d_inputs, d_single, bias=False) self.s_to_z = nn.Linear(d_inputs, d_pair, bias=False) self.s_to_z_transpose = nn.Linear(d_inputs, d_pair, bias=False) self.s_to_z_prod_in1 = nn.Linear(d_inputs, d_pair, bias=False) self.s_to_z_prod_in2 = nn.Linear(d_inputs, d_pair, bias=False) self.s_to_z_prod_out = nn.Linear(d_pair, d_pair, bias=False) self.s_input_to_s = nn.Linear(d_inputs, d_single, bias=False) self.s_inputs_norm = nn.LayerNorm(d_inputs) self.z_norm = nn.LayerNorm(d_pair) self.row_attention_pooling = RowAttentionPooling( d_pair=d_pair, d_single=d_single ) pf = ch.folding_trunk self.folding_trunk = FoldingTrunk( n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4 ) # Heads. self.plddt_ln = nn.LayerNorm(d_single) max_atoms_per_token = 23 self.plddt_weight = nn.Parameter( torch.zeros(max_atoms_per_token, d_single, ch.num_plddt_bins) ) self.pae_ln = nn.LayerNorm(d_pair) self.pae_head = nn.Linear(d_pair, ch.num_pae_bins, bias=False) self.pde_ln = nn.LayerNorm(d_pair) self.pde_head = nn.Linear(d_pair, ch.num_pde_bins, bias=False) self.resolved_ln = nn.LayerNorm(d_single) # 2 = resolved logits ([unresolved, resolved]). self.resolved_weight = nn.Parameter( torch.zeros(max_atoms_per_token, d_single, 2) ) def set_kernel_backend(self, backend: str | None) -> None: self.folding_trunk.set_kernel_backend(backend) def set_chunk_size(self, chunk_size: int | None) -> None: self.folding_trunk.set_chunk_size(chunk_size) @staticmethod def _repeat_batch(x: Tensor, num_diffusion_samples: int) -> Tensor: return ( x if num_diffusion_samples == 1 else x.repeat_interleave(num_diffusion_samples, 0) ) @staticmethod def _flatten_sample_axis(x: Tensor) -> Tensor: if x.ndim == 4: b, mult, n, c = x.shape return x.reshape(b * mult, n, c) return x def forward( self, s_inputs: Tensor, z: Tensor, x_pred: Tensor, distogram_atom_idx: Tensor, token_attention_mask: Tensor, atom_to_token: Tensor, atom_attention_mask: Tensor, asym_id: Tensor, mol_type: Tensor, num_diffusion_samples: int = 1, relative_position_encoding: Tensor | None = None, token_bonds_encoding: Tensor | None = None, ) -> dict[str, Tensor]: s_inputs_normed = self.s_inputs_norm(s_inputs) z_base = self.z_norm(z) if relative_position_encoding is not None: z_base = z_base + relative_position_encoding if token_bonds_encoding is not None: z_base = z_base + token_bonds_encoding z_base = z_base + self.s_to_z(s_inputs_normed).unsqueeze(2) z_base = z_base + self.s_to_z_transpose(s_inputs_normed).unsqueeze(1) z_base = z_base + self.s_to_z_prod_out( self.s_to_z_prod_in1(s_inputs_normed)[:, :, None, :] * self.s_to_z_prod_in2(s_inputs_normed)[:, None, :, :] ) pair = self._repeat_batch(z_base, num_diffusion_samples) x_pred_flat = self._flatten_sample_axis(x_pred) atom_to_token_m = self._repeat_batch(atom_to_token, num_diffusion_samples) atom_mask_m = self._repeat_batch(atom_attention_mask, num_diffusion_samples) rep_idx_m = self._repeat_batch(distogram_atom_idx, num_diffusion_samples).long() mask = self._repeat_batch(token_attention_mask, num_diffusion_samples) Bm = pair.shape[0] rep_coords = gather_rep_atom_coords(x_pred_flat, rep_idx_m) rep_distances = torch.cdist( rep_coords, rep_coords, compute_mode="donot_use_mm_for_euclid_dist" ) distogram_bins = ( (rep_distances.unsqueeze(-1) > self.boundaries).sum(dim=-1).long() ) pair = pair + self.dist_bin_pairwise_embed(distogram_bins) pair_mask = mask[:, :, None].float() * mask[:, None, :].float() # FoldingTrunk handles the bf16 cast internally during inference so # each block's fused trimul engages. In-place residual avoids an # extra fp32 pair allocation. with torch.amp.autocast("cuda", enabled=pair.is_cuda, dtype=torch.bfloat16): pair_delta = self.folding_trunk(pair, pair_attention_mask=pair_mask) pair.add_(pair_delta.float()) del pair_delta single = self.row_attention_pooling(pair, mask) atom_mask_f = atom_mask_m.float() s_at_atoms = gather_token_to_atom(single, atom_to_token_m) s_at_atoms_ln = self.plddt_ln(s_at_atoms) intra_idx = _compute_intra_token_idx(atom_to_token_m) intra_idx = intra_idx.clamp(max=self.plddt_weight.shape[0] - 1) w_plddt = self.plddt_weight[intra_idx] plddt_logits = torch.einsum("...c,...cb->...b", s_at_atoms_ln, w_plddt) plddt_per_atom = _categorical_mean(plddt_logits, start=0.0, end=1.0) L = single.shape[1] plddt_sum = torch.zeros(Bm, L, device=single.device, dtype=plddt_per_atom.dtype) atom_count = torch.zeros( Bm, L, device=single.device, dtype=plddt_per_atom.dtype ) atom_mask_t = atom_mask_f.to(plddt_per_atom.dtype) plddt_sum.scatter_add_(1, atom_to_token_m, plddt_per_atom * atom_mask_t) atom_count.scatter_add_(1, atom_to_token_m, atom_mask_t) plddt = plddt_sum / atom_count.clamp(min=1e-6) complex_plddt = (plddt_per_atom * atom_mask_f).sum(dim=-1) / ( atom_mask_f.sum(dim=-1) + _EPS ) expanded_type = self._repeat_batch(mol_type, num_diffusion_samples) expanded_asym = self._repeat_batch(asym_id, num_diffusion_samples) is_ligand = (expanded_type == _NONPOLYMER_ID).float() inter_chain = ( expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2) ).float() near_contact = (rep_distances < 8).float() interface_per_token = ( near_contact * inter_chain * (1.0 - is_ligand).unsqueeze(-1) ).amax(dim=-1) iplddt_weight = torch.where( is_ligand.bool(), torch.full_like(interface_per_token, 2.0), interface_per_token, ) iplddt_weight_atoms = gather_token_to_atom( iplddt_weight.unsqueeze(-1), atom_to_token_m ).squeeze(-1) atom_iplddt_w = atom_mask_f * iplddt_weight_atoms complex_iplddt = (plddt_per_atom * atom_iplddt_w).sum(dim=-1) / ( atom_iplddt_w.sum(dim=-1) + _EPS ) plddt_ca = plddt_per_atom.gather(1, rep_idx_m) # PAE pae_logits = self.pae_head(self.pae_ln(pair)) pae = _categorical_mean(pae_logits, start=0.0, end=32.0).detach() # PDE pde_logits = self.pde_head(self.pde_ln(pair)) pde = _categorical_mean(pde_logits, start=0.0, end=32.0).detach() # Resolved (per-atom binary). s_at_atoms_res = self.resolved_ln(s_at_atoms) w_res = self.resolved_weight[intra_idx] resolved_logits = torch.einsum("...c,...cb->...b", s_at_atoms_res, w_res) # pTM / ipTM from pae_logits. n_bins = pae_logits.shape[-1] bin_width = 32.0 / n_bins bin_centers = torch.arange( 0.5 * bin_width, 32.0, bin_width, device=pae_logits.device ) mask_f = mask.float() N_res = mask_f.sum(dim=-1, keepdim=True) d0 = 1.24 * (N_res.clamp(min=19) - 15) ** (1 / 3) - 1.8 tm_per_bin = 1 / (1 + (bin_centers / d0) ** 2) pae_probs = F.softmax(pae_logits, dim=-1) tm_expected = (pae_probs * tm_per_bin[:, None, None, :]).sum(dim=-1) pair_mask_2d = mask_f.unsqueeze(-1) * mask_f.unsqueeze(-2) ptm_per_row = (tm_expected * pair_mask_2d).sum(dim=-1) / ( pair_mask_2d.sum(dim=-1) + _EPS ) ptm = ptm_per_row.max(dim=-1).values inter_chain_mask = ( expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2) ).float() * pair_mask_2d iptm_per_row = (tm_expected * inter_chain_mask).sum(dim=-1) / ( inter_chain_mask.sum(dim=-1) + _EPS ) iptm = iptm_per_row.max(dim=-1).values max_chain_id = int(expanded_asym.max().item()) if Bm > 0 else 0 n_chains = max_chain_id + 1 pair_chains_iptm = torch.zeros( Bm, n_chains, n_chains, device=tm_expected.device, dtype=tm_expected.dtype ) for c1 in range(n_chains): chain_c1 = (expanded_asym == c1).float() * mask_f if chain_c1.sum() == 0: continue for c2 in range(n_chains): chain_c2 = (expanded_asym == c2).float() * mask_f pair_m = chain_c1.unsqueeze(-1) * chain_c2.unsqueeze(-2) denom = pair_m.sum(dim=(-1, -2)) + _EPS pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum( dim=(-1, -2) ) / denom return { "plddt_logits": plddt_logits, "plddt": plddt.detach(), "plddt_per_atom": plddt_per_atom.detach(), "plddt_ca": plddt_ca.detach(), "complex_plddt": complex_plddt.detach(), "complex_iplddt": complex_iplddt.detach(), "pae_logits": pae_logits, "pae": pae, "pde_logits": pde_logits, "pde": pde, "resolved_logits": resolved_logits, "ptm": ptm.detach(), "iptm": iptm.detach(), "pair_chains_iptm": pair_chains_iptm.detach(), } def _inverse_softplus(value: float) -> float: return value + math.log(-math.expm1(-value)) def _convert_te_modules_to_fp8_inplace(module: nn.Module) -> None: """Re-init each TE module via quantized_model_init so weights live as fp8. Must be called inside torch.no_grad(); covers nn.Linear, te.Linear, te.LayerNormLinear, te.LayerNormMLP — the last two hold 99% of ESMC weight. """ if not TE_AVAILABLE: raise RuntimeError("transformer_engine is not available; cannot use fp8.") quantized_model_init = importlib.import_module( "transformer_engine.pytorch" ).quantized_model_init def _walk(mod: nn.Module) -> None: for name, child in list(mod.named_children()): replaced = False if isinstance(child, nn.Linear): in_f, out_f = child.in_features, child.out_features has_bias = child.bias is not None device = child.weight.device dtype = child.weight.dtype w = child.weight.data b = child.bias.data if has_bias else None setattr(mod, name, nn.Identity()) del child torch.cuda.empty_cache() with quantized_model_init(enabled=True): new_mod = te.Linear( # type: ignore[union-attr] in_f, out_f, bias=has_bias, params_dtype=dtype ).to(device) new_mod.weight.quantize_(w) # type: ignore[attr-defined,operator] if has_bias: assert b is not None new_mod.bias.data.copy_(b) # type: ignore[union-attr] del w, b replaced = True elif isinstance(child, te.Linear): # type: ignore[union-attr] # te.Linear with bf16 weight → re-init inside quantized_model_init for fp8. in_f, out_f = child.in_features, child.out_features has_bias = child.bias is not None device = child.weight.device dtype = ( child.weight.dtype if not hasattr(child.weight, "_data") else torch.bfloat16 ) state = {k: v.detach().clone() for k, v in child.state_dict().items()} setattr(mod, name, nn.Identity()) del child torch.cuda.empty_cache() with quantized_model_init(enabled=True): new_mod = te.Linear( # type: ignore[union-attr] in_f, out_f, bias=has_bias, params_dtype=dtype, # type: ignore[arg-type] ).to(device) # type: ignore[arg-type] new_mod.load_state_dict(state, strict=False) replaced = True elif ( hasattr(te, "LayerNormLinear") and isinstance(child, te.LayerNormLinear) # type: ignore[union-attr] ): state = {k: v.detach().clone() for k, v in child.state_dict().items()} hidden_size = child.in_features out_features = child.out_features has_bias = child.use_bias device = next(child.parameters()).device setattr(mod, name, nn.Identity()) del child torch.cuda.empty_cache() with quantized_model_init(enabled=True): new_mod = te.LayerNormLinear( # type: ignore[union-attr] hidden_size, out_features, bias=has_bias, params_dtype=torch.bfloat16, ).to(device) new_mod.load_state_dict(state, strict=False) replaced = True elif ( hasattr(te, "LayerNormMLP") and isinstance(child, te.LayerNormMLP) # type: ignore[union-attr] ): state = {k: v.detach().clone() for k, v in child.state_dict().items()} fc1_weight: Tensor = child.fc1_weight # type: ignore[attr-defined] hidden_size = int(fc1_weight.shape[1]) # fc1 packed as (2*ffn_hidden_size, hidden_size) for swiglu. ffn_hidden_size = int(fc1_weight.shape[0]) // 2 has_bias = ( getattr(child, "fc1_bias", None) is not None and child.fc1_bias is not None # type: ignore[attr-defined] ) device = fc1_weight.device setattr(mod, name, nn.Identity()) del child torch.cuda.empty_cache() with quantized_model_init(enabled=True): new_mod = te.LayerNormMLP( # type: ignore[union-attr] hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, bias=has_bias, activation="swiglu", params_dtype=torch.bfloat16, ).to(device) # type: ignore[arg-type] new_mod.load_state_dict(state, strict=False) replaced = True if replaced: # Freeze via .eval()+.requires_grad_(False); per-param ops would unwrap Float8Tensor. new_mod.eval().requires_grad_(False) setattr(mod, name, new_mod) torch.cuda.empty_cache() else: _walk(child) _walk(module) torch.cuda.empty_cache() @contextmanager def _lm_precision_context(fp8: bool): """bf16 autocast (+ optional TE fp8 autocast) around the LM forward. te.autocast keeps te.Linear outputs bf16 instead of the fp32 default (~425 MB at L=1024 in the hidden-state cache). """ with torch.autocast(device_type="cuda", dtype=torch.bfloat16): if fp8 and TE_AVAILABLE: fp8_recipe = DelayedScaling( # type: ignore[misc] fp8_format=Format.HYBRID, # type: ignore[union-attr] amax_history_len=1, amax_compute_algo="most_recent", ) with te.autocast(enabled=True, recipe=fp8_recipe): # type: ignore[union-attr] yield else: yield class ESMFold2Model(FastPLMTestTimeTrainingMixin, PreTrainedModel): """ESMFold2 — all-atom structure prediction with an ESMC PLM backbone. This is the standard released ESMFold2 architecture (uses a linear- recurrent trunk, internally referred to as "parcae"). Forward kwargs that callers commonly override: * ``num_loops`` (default ``config.num_loops``): trunk refinement loops. * ``num_diffusion_samples`` (default ``config.num_diffusion_samples``): parallel structure samples; the confidence head re-runs once per sample, so memory scales linearly. Pass ``1`` for cheap inference. * ``num_sampling_steps`` (default ``config.structure_head.inference_num_steps``): diffusion ODE solver steps. Lower for speed, higher for quality. Memory / perf knobs: * ``model.set_chunk_size(int|None)``: caps L² ops (triangle / OPM / pair transition) at this token-axis chunk. Default 64 — fits L≈2k on an 80 GB GPU. Pass ``None`` for faster inference at L<600. * ``model.set_kernel_backend(None | "fused" | "cuequivariance")``: select kernel backend (None = reference path). """ config_class = ESMFold2Config _keys_to_ignore_on_load_unexpected = [r"\._extra_state$"] def __init__(self, config: ESMFold2Config) -> None: super().__init__(config) d_inputs = config.inputs.d_inputs d_pair = config.d_pair self.inputs_embedder = InputsEmbedder(config) self.z_init_1 = nn.Linear(d_inputs, d_pair, bias=False) self.z_init_2 = nn.Linear(d_inputs, d_pair, bias=False) self.rel_pos = ResIdxAsymIdSymIdEntityIdEncoding( n_relative_residx_bins=config.n_relative_residx_bins, n_relative_chain_bins=config.n_relative_chain_bins, d_pair=d_pair, ) self.token_bonds = nn.Linear(1, d_pair, bias=False) self.language_model = LanguageModelShim( d_z=d_pair, d_model=config.lm_d_model, num_layers=config.lm_num_layers ) self._esmc: nn.Module | None = None self._esmc_fp8: bool = False self._ttt_lm_head: nn.Module | None = None self._esmfold2_input_builder: Any | None = None pf = config.folding_trunk self.folding_trunk = FoldingTrunk( n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4 ) if config.lm_encoder.enabled: self.lm_encoder: FoldingTrunk | None = FoldingTrunk( n_layers=config.lm_encoder.n_layers, d_pair=d_pair, expansion_ratio=4 ) else: self.lm_encoder = None self.parcae_input_norm = nn.LayerNorm(d_pair) self.parcae_log_a = nn.Parameter(torch.zeros(d_pair)) parcae_decay_init = math.sqrt(1.0 / 5.0) parcae_delta_init = -math.log(parcae_decay_init) self.parcae_log_delta = nn.Parameter( torch.full( (d_pair,), _inverse_softplus(parcae_delta_init), dtype=torch.float32 ) ) self.parcae_b_cont = nn.Parameter(torch.eye(d_pair)) self.parcae_readout = nn.Linear(d_pair, d_pair, bias=False) nn.init.eye_(self.parcae_readout.weight) self.parcae_coda = FoldingTrunk( n_layers=config.parcae.coda_n_layers, d_pair=d_pair, expansion_ratio=4 ) # Heads -------------------------------------------------------------- self.structure_head = DiffusionStructureHead(config) self.distogram_head = nn.Linear( d_pair, config.structure_head.distogram_bins, bias=True ) self.confidence_head = ConfidenceHead(config) msa_cfg = config.msa_encoder self.msa_encoder = None if msa_cfg.enabled: self.msa_encoder = MSAEncoder( d_msa=msa_cfg.d_msa, d_pair=d_pair, d_inputs=d_inputs, d_hidden=msa_cfg.d_hidden, n_layers=msa_cfg.n_layers, n_heads_msa=msa_cfg.n_heads_msa, msa_head_width=msa_cfg.msa_head_width, ) self.post_init() self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"}) def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None: """Load the FastPLMs ESM++ LM used as the ESMFold2 PLM backbone. ``precision``: ``"bf16"`` (default), ``"fp32"``, or opt-in ``"fp8"``. """ dtype_map = { "bf16": torch.bfloat16, "fp32": torch.float32, "fp8": torch.bfloat16, } if precision not in dtype_map: raise ValueError(f"precision must be one of {list(dtype_map)}, got {precision!r}") if precision == "fp8" and not TE_AVAILABLE: raise RuntimeError( "esmc_precision='fp8' requires transformer_engine.pytorch." ) dtype = dtype_map[precision] esmc = _load_fastplms_esmplusplus_for_esmfold2( esmc_model_path=esmc_model_path, attn_backend=self.config.esmc_attn_backend, device=self.device, dtype=dtype, ) assert esmc.config.hidden_size == self.config.lm_d_model, ( f"ESMFold2 expected lm_d_model={self.config.lm_d_model}, " f"but loaded ESM++ hidden_size={esmc.config.hidden_size}." ) assert esmc.config.num_hidden_layers == self.config.lm_num_layers, ( f"ESMFold2 expected lm_num_layers={self.config.lm_num_layers}, " f"but loaded ESM++ num_hidden_layers={esmc.config.num_hidden_layers}." ) for p in esmc.parameters(): p.requires_grad_(False) if precision == "fp8": with torch.no_grad(): _convert_te_modules_to_fp8_inplace(esmc) self._esmc_fp8 = precision == "fp8" self._esmc = esmc self._ttt_lm_head = None def _ensure_ttt_lm_head(self) -> None: assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True." if self._esmc_fp8: raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.") if self._ttt_lm_head is not None: return try: from fastplms.esm_plusplus.modeling_esm_plusplus import ( ESMplusplusConfig, ESMplusplusForMaskedLM, ) except ImportError: from .modeling_esm_plusplus import ( ESMplusplusConfig, ESMplusplusForMaskedLM, ) esmc_config = ESMplusplusConfig.from_pretrained(self.config.esmc_id) esmc_config.attn_backend = self.config.esmc_attn_backend mlm, loading_info = ESMplusplusForMaskedLM.from_pretrained( self.config.esmc_id, config=esmc_config, output_loading_info=True, ) missing_head_keys = [ key for key in loading_info["missing_keys"] if key.startswith("sequence_head") ] assert len(missing_head_keys) == 0, ( f"ESMFold2 TTT could not load a pretrained ESM++ MLM head from " f"{self.config.esmc_id}: missing {missing_head_keys}" ) dtype = next(self._esmc.parameters()).dtype mlm = mlm.to(device=self.device, dtype=dtype).eval() self._ttt_lm_head = mlm.sequence_head self._ttt_lm_head.requires_grad_(False) del mlm def _ttt_get_trainable_modules(self) -> list[nn.Module]: assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True." if self._esmc_fp8: raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.") return [self._esmc] def _ttt_tokenize( self, seq: str | list[str] | None = None, input_ids: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: del kwargs if input_ids is not None: return input_ids assert seq is not None, "Pass either seq or input_ids for ESMFold2 TTT." sequences = [seq] if isinstance(seq, str) else seq token_to_id = {token: idx for idx, token in enumerate(SEQUENCE_VOCAB)} encoded = [] for sequence in sequences: token_ids = [SEQUENCE_BOS_TOKEN] for amino_acid in sequence: token_ids.append( token_to_id[amino_acid if amino_acid in token_to_id else "X"] ) token_ids.append(SEQUENCE_EOS_TOKEN) encoded.append(token_ids) max_len = max(len(token_ids) for token_ids in encoded) input_tensor = torch.full( (len(encoded), max_len), SEQUENCE_PAD_TOKEN, dtype=torch.long, ) for row, token_ids in enumerate(encoded): input_tensor[row, : len(token_ids)] = torch.tensor( token_ids, dtype=torch.long, ) return input_tensor def _ttt_mask_token(self) -> int: return SEQUENCE_MASK_TOKEN def _ttt_padding_token(self) -> int: return SEQUENCE_PAD_TOKEN def _ttt_replacement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor: return torch.arange( SEQUENCE_STANDARD_AA_MIN_TOKEN, SEQUENCE_STANDARD_AA_MAX_TOKEN, device=input_ids.device, dtype=input_ids.dtype, ) def _ttt_non_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor: return (input_ids >= SEQUENCE_STANDARD_AA_MIN_TOKEN) & ( input_ids < SEQUENCE_STANDARD_AA_MAX_TOKEN ) def _ttt_predict_logits( self, batch: torch.Tensor | dict[str, torch.Tensor], **kwargs, ) -> torch.Tensor: del kwargs assert isinstance(batch, torch.Tensor), ( "ESMFold2 TTT expects input_ids tensors." ) assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True." if self._esmc_fp8: raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.") self._ensure_ttt_lm_head() assert self._ttt_lm_head is not None attention_mask = batch.ne(SEQUENCE_PAD_TOKEN) output = self._esmc( input_ids=batch, attention_mask=attention_mask, return_dict=True, compute_sae=False, ) return self._ttt_lm_head(output.last_hidden_state) @classmethod def from_pretrained( cls, pretrained_model_name_or_path, *args, load_esmc: bool = True, **kwargs ): if cls is ESMFold2Model and "config" not in kwargs: config = ESMFold2Config.from_pretrained( pretrained_model_name_or_path, **kwargs ) if config.type == "experimental": raise ValueError( "FastPLMs ESMFold2 supports the released ESMFold2 and " "ESMFold2-Fast checkpoints. Experimental ESMFold2 configs " "are not part of the self-contained AutoModel package." ) kwargs["config"] = config # Pop the precision knob before forwarding to the HF loader. esmc_precision = kwargs.pop("esmc_precision", "bf16") model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if load_esmc: model.load_esmc(model.config.esmc_id, precision=esmc_precision) return model def set_kernel_backend(self, backend: str | None) -> None: """Select kernel backend. Args: backend: ``None`` (reference path), ``"fused"`` (vendored Triton kernels), or ``"cuequivariance"`` (cuequivariance kernels where applicable; vanilla python fallback otherwise). """ self.folding_trunk.set_kernel_backend(backend) if self.lm_encoder is not None: self.lm_encoder.set_kernel_backend(backend) self.parcae_coda.set_kernel_backend(backend) self.confidence_head.set_kernel_backend(backend) self.structure_head.set_kernel_backend(backend) def apply_torch_compile( self, mode: str = "fixed_seqlen", dynamic: bool | None = None ) -> None: """Compile L²-heavy blocks. ``mode='fixed_seqlen'`` recompiles per L; ``'dynamic_seqlen'`` compiles once. Does NOT stack with our Triton kernels — call ``set_kernel_backend(None)`` before compiling. """ import torch._dynamo torch._dynamo.config.cache_size_limit = 512 # type: ignore[attr-defined] torch._dynamo.config.accumulated_cache_size_limit = 512 # type: ignore[attr-defined] # capture_scalar_outputs avoids graph breaks at .item() in atom-attention path. torch._dynamo.config.capture_scalar_outputs = True # type: ignore[attr-defined] if dynamic is None: dynamic = mode == "dynamic_seqlen" kwargs: dict = {"dynamic": dynamic} from .modeling_esmfold2_common import ( DiffusionModule, DiffusionTransformer, PairUpdateBlock, ) compile_targets = ( PairUpdateBlock, DiffusionTransformer, DiffusionModule, MSAEncoderBlock, ) def _maybe_compile(module: nn.Module) -> None: if isinstance(module, compile_targets): module.forward = torch.compile(module.forward, **kwargs) # type: ignore[assignment] self.apply(_maybe_compile) def set_chunk_size(self, chunk_size: int | None) -> None: self.folding_trunk.set_chunk_size(chunk_size) if self.lm_encoder is not None: self.lm_encoder.set_chunk_size(chunk_size) self.parcae_coda.set_chunk_size(chunk_size) self.confidence_head.set_chunk_size(chunk_size) if self.msa_encoder is not None: self.msa_encoder.set_chunk_size(chunk_size) def _compute_lm_hidden_states( self, input_ids: Tensor, asym_id: Tensor, residue_index: Tensor, mol_type: Tensor, tok_mask: Tensor, lm_mask_pct: float = 0.0, ) -> Tensor: assert self._esmc is not None # fp8 TE kernels require prod(shape[:-1]) % 8 == 0. pad_to = 8 if self._esmc_fp8 else None with _lm_precision_context(self._esmc_fp8): return compute_lm_hidden_states( self._esmc, input_ids, asym_id, residue_index, mol_type, tok_mask, pad_to_multiple=pad_to, lm_mask_pct=lm_mask_pct, mask_token_id=SEQUENCE_MASK_TOKEN, ) def _discretized_dynamics(self) -> tuple[Tensor, Tensor]: delta = F.softplus(self.parcae_log_delta) a = torch.exp(-delta * torch.exp(self.parcae_log_a)) b = delta[:, None] * self.parcae_b_cont return a, b def _init_pair_state(self, ref: Tensor) -> Tensor: std = math.sqrt(2.0 / (5.0 * ref.shape[-1])) state = torch.empty_like(ref, dtype=torch.float32) nn.init.trunc_normal_(state, mean=0.0, std=std, a=-3 * std, b=3 * std) return state.to(dtype=ref.dtype) def _run_one_loop( self, z: Tensor, z_init: Tensor, lm_z: Tensor | None, _msa_inputs: dict | None, pair_mask: Tensor, a: Tensor, b_mat: Tensor, tok_mask: Tensor, total_steps: int, ) -> Tensor: # Helper method (not inline) so per-iter locals free on return — # otherwise leaks ~2 GB L²×c_z into distogram/sample scope. # training=True forces dropout under eval(), matching the per-loop # dropout strategy used at train time. lm_cfg = self.config.lm_encoder _per_loop_lm_dropout = ( lm_z is not None and getattr(lm_cfg, "per_loop_lm_dropout", False) and getattr(lm_cfg, "lm_dropout", 0.0) > 0.0 ) _lm_dropout_p = getattr(lm_cfg, "lm_dropout", 0.0) for _ in range(total_steps): if _per_loop_lm_dropout: assert lm_z is not None # narrowed by _per_loop_lm_dropout lm_z_i: Tensor | None = F.dropout(lm_z, p=_lm_dropout_p, training=True) else: lm_z_i = lm_z refined_lm_z: Tensor | None = None if lm_z_i is not None and self.lm_encoder is not None: refined_lm_z = self.lm_encoder( lm_z_i.to(z_init.dtype), pair_attention_mask=pair_mask ) z_inject_pair = z_init if lm_z_i is not None and self.lm_encoder is None: z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype) if self.msa_encoder is not None and _msa_inputs is not None: msa_i, mask_i, hd_i, dv_i = maybe_subsample_msa( _msa_inputs["msa"], _msa_inputs["msa_attention_mask"], _msa_inputs["has_deletion"], _msa_inputs["deletion_value"], max_depth=_msa_inputs["max_depth"], enabled=_msa_inputs["subsample_enabled"], ) B_msa, M, L_msa = msa_i.shape msa_oh = F.one_hot( msa_i.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES ).float() msa_attn = ( mask_i.permute(0, 2, 1).float() if mask_i is not None else tok_mask[:, :, None].expand(-1, -1, M).float() ) # Bias-free MSAEncoder.embed requires zeroed padding. msa_oh = msa_oh * msa_attn.unsqueeze(-1) hd = ( hd_i.permute(0, 2, 1).float() if hd_i is not None else torch.zeros(B_msa, L_msa, M, device=msa_i.device) ) dv = ( dv_i.permute(0, 2, 1).float() if dv_i is not None else torch.zeros(B_msa, L_msa, M, device=msa_i.device) ) msa_pair = self.msa_encoder( x_pair=z_inject_pair, x_inputs=_msa_inputs["x_inputs"], msa_oh=msa_oh, has_deletion=hd, deletion_value=dv, msa_attention_mask=msa_attn, ).to(z_inject_pair.dtype) z_inject_pair = ( msa_pair if self.config.msa_encoder_overwrite else (z_inject_pair + msa_pair) ) if refined_lm_z is not None: z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype) injected_pair = self.parcae_input_norm(z_inject_pair) z = a * z + F.linear(injected_pair.to(z.dtype), b_mat) z = self.folding_trunk(z, pair_attention_mask=pair_mask) return z @torch.inference_mode() def forward( self, token_index: Tensor, residue_index: Tensor, asym_id: Tensor, sym_id: Tensor, entity_id: Tensor, mol_type: Tensor, res_type: Tensor, token_bonds: Tensor, token_attention_mask: Tensor, ref_pos: Tensor, ref_element: Tensor, ref_charge: Tensor, ref_atom_name_chars: Tensor, ref_space_uid: Tensor, atom_attention_mask: Tensor, atom_to_token: Tensor, distogram_atom_idx: Tensor, deletion_mean: Tensor | None = None, msa: Tensor | None = None, has_deletion: Tensor | None = None, deletion_value: Tensor | None = None, msa_attention_mask: Tensor | None = None, input_ids: Tensor | None = None, lm_hidden_states: Tensor | None = None, num_loops: int | None = None, num_diffusion_samples: int | None = None, num_sampling_steps: int | None = None, lm_mask_pct: float | None = None, msa_max_depth: int = 1024, msa_column_mask_rate: float = 0.1, msa_subsample_at_inference: bool = True, **kwargs, ) -> dict[str, Tensor]: tok_mask = token_attention_mask atm_mask = atom_attention_mask disto_idx = distogram_atom_idx n_loops: int = num_loops if num_loops is not None else self.config.num_loops n_samples: int = ( num_diffusion_samples if num_diffusion_samples is not None else self.config.num_diffusion_samples ) total_steps = max(1, n_loops + 1) if res_type.dim() == 2: res_type_oh = F.one_hot(res_type.long(), num_classes=NUM_RES_TYPES).float() res_type_oh = res_type_oh * tok_mask.unsqueeze(-1).float() else: res_type_oh = res_type.float() if msa is not None: msa_oh_profile = F.one_hot(msa.long(), num_classes=NUM_RES_TYPES).float() if msa_attention_mask is not None: mask_f = msa_attention_mask.float().unsqueeze(-1) msa_oh_profile = msa_oh_profile * mask_f valid_seq_count = msa_attention_mask.float().sum(dim=1).clamp(min=1) profile = msa_oh_profile.sum(dim=1) / valid_seq_count.unsqueeze(-1) else: profile = msa_oh_profile.mean(dim=1) else: profile = res_type_oh if deletion_mean is None: deletion_mean = torch.zeros( res_type.shape[0], res_type.shape[1], device=res_type.device ) ref_element_oh = F.one_hot( ref_element.long(), num_classes=MAX_ATOMIC_NUMBER ).float() ref_atom_name_chars_oh = F.one_hot( ref_atom_name_chars.long(), num_classes=CHAR_VOCAB_SIZE ).float() # Bias-free downstream Linears require zeroed padding. atm_mask_f = atm_mask.float() ref_element_oh = ref_element_oh * atm_mask_f.unsqueeze(-1) ref_atom_name_chars_oh = ref_atom_name_chars_oh * atm_mask_f.unsqueeze( -1 ).unsqueeze(-1) atom_to_token = atom_to_token * atm_mask.long() use_amp = ref_pos.device.type == "cuda" with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16): x_inputs = self.inputs_embedder( aatype=res_type_oh, profile=profile.float(), deletion_mean=deletion_mean.float(), ref_pos=ref_pos, atom_attention_mask=atm_mask, ref_space_uid=ref_space_uid, ref_charge=ref_charge, ref_element=ref_element_oh, ref_atom_name_chars=ref_atom_name_chars_oh, atom_to_token=atom_to_token, ) z_init = self.z_init_1(x_inputs).unsqueeze(2) + self.z_init_2( x_inputs ).unsqueeze(1) relative_position_encoding = self.rel_pos( residue_index=residue_index, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, token_index=token_index, ) token_bonds_encoding = self.token_bonds(token_bonds.float()) z_init = z_init + relative_position_encoding + token_bonds_encoding if ( lm_hidden_states is None and input_ids is not None and self._esmc is not None ): lm_hidden_states = self._compute_lm_hidden_states( input_ids, asym_id, residue_index, mol_type, tok_mask, lm_mask_pct=( self.config.lm_mask_pct if lm_mask_pct is None else lm_mask_pct ), ) lm_z: Tensor | None = None if lm_hidden_states is not None: lm_z = self.language_model(lm_hidden_states.detach()) del lm_hidden_states pair_mask = tok_mask[:, :, None].float() * tok_mask[:, None, :].float() z = self._init_pair_state(z_init) a, b = self._discretized_dynamics() a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype) b_mat = b.to(device=z.device, dtype=z.dtype) _msa_inputs: dict | None = None if self.msa_encoder is not None and msa is not None: msa_attention_mask = maybe_apply_msa_column_masking( msa_attention_mask, msa_column_mask_rate, ) _msa_inputs = dict( x_inputs=x_inputs, msa=msa, msa_attention_mask=msa_attention_mask, has_deletion=has_deletion, deletion_value=deletion_value, max_depth=msa_max_depth, subsample_enabled=msa_subsample_at_inference, ) # Method call (not inline loop) frees per-iter L²×c_z locals. z = self._run_one_loop( z=z, z_init=z_init, lm_z=lm_z, _msa_inputs=_msa_inputs, pair_mask=pair_mask, a=a, b_mat=b_mat, tok_mask=tok_mask, total_steps=total_steps, ) del z_init, lm_z, _msa_inputs, a, b_mat z = self.parcae_readout(z) z = self.parcae_coda(z, pair_attention_mask=pair_mask) z = z.float() distogram_logits = self.distogram_head(z + z.transpose(-2, -3)) structure_output = self.structure_head.sample( z_trunk=z, s_inputs=x_inputs, s_trunk=None, relative_position_encoding=relative_position_encoding, ref_pos=ref_pos, ref_charge=ref_charge, ref_mask=atm_mask, ref_element=ref_element_oh, ref_atom_name_chars=ref_atom_name_chars_oh, ref_space_uid=ref_space_uid, tok_idx=atom_to_token, asym_id=asym_id, residue_index=residue_index, entity_id=entity_id, token_index=token_index, sym_id=sym_id, token_attention_mask=tok_mask, num_diffusion_samples=n_samples, num_sampling_steps=num_sampling_steps, return_atom_repr=False, denoising_early_exit_rmsd=None, ) sample_coords = structure_output["sample_atom_coords"] assert sample_coords is not None output: dict[str, Tensor] = {"distogram_logits": distogram_logits} output["sample_atom_coords"] = sample_coords confidence_output = self.confidence_head( s_inputs=x_inputs.detach(), z=z.detach().float(), x_pred=sample_coords.detach(), distogram_atom_idx=disto_idx, token_attention_mask=tok_mask, atom_to_token=atom_to_token, atom_attention_mask=atm_mask, asym_id=asym_id, mol_type=mol_type, num_diffusion_samples=n_samples, relative_position_encoding=relative_position_encoding.detach(), token_bonds_encoding=token_bonds_encoding.detach(), ) output.update(confidence_output) output["atom_pad_mask"] = ( atm_mask.unsqueeze(0) if atm_mask.dim() == 1 else atm_mask ) output["residue_index"] = residue_index output["entity_id"] = entity_id return output @torch.no_grad() def infer_protein(self, seq: str, **forward_kwargs) -> dict: from .protein_utils import prepare_protein_features features = prepare_protein_features(seq) features = {k: v.to(self.device) for k, v in features.items()} return self(**features, **forward_kwargs) @property def input_builder(self): if self._esmfold2_input_builder is None: from .esmfold2_processor import ESMFold2InputBuilder self._esmfold2_input_builder = ESMFold2InputBuilder() return self._esmfold2_input_builder @property def input_types(self): from . import esmfold2_types return esmfold2_types def prepare_structure_input(self, input, seed: int | None = None): return self.input_builder.prepare_input(input, seed=seed, device=self.device) def fold( self, input, *, num_loops: int = 3, num_sampling_steps: int = 50, num_diffusion_samples: int = 1, seed: int | None = None, noise_scale: float | None = None, step_scale: float | None = None, max_inference_sigma: int | None = None, early_exit: bool = False, complex_id: str = "pred", ): return self.input_builder.fold( self, input, num_loops=num_loops, num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, seed=seed, noise_scale=noise_scale, step_scale=step_scale, max_inference_sigma=max_inference_sigma, early_exit=early_exit, complex_id=complex_id, ) def _fold_protein_no_ttt( self, sequence: str, *, chain_id: str = "A", msa: Any | None = None, msa_path: str | Path | None = None, msa_max_sequences: int | None = None, num_loops: int = 3, num_sampling_steps: int = 50, num_diffusion_samples: int = 1, seed: int | None = None, complex_id: str = "pred", ): from .esmfold2_types import MSA, ProteinInput, StructurePredictionInput assert not ( msa is not None and msa_path is not None ), "Pass at most one of msa or msa_path." if msa_path is not None: msa = MSA.from_a3m(msa_path, max_sequences=msa_max_sequences) if msa is not None: query = str(msa.query).replace("-", "").upper() assert query == sequence.upper(), ( f"MSA query does not match sequence: expected {sequence.upper()!r}, got {query!r}" ) input = StructurePredictionInput( sequences=[ProteinInput(id=chain_id, sequence=sequence, msa=msa)] ) return self.fold( input, num_loops=num_loops, num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, seed=seed, complex_id=complex_id, ) @staticmethod def _ttt_mean_plddt(result) -> float: assert result.plddt is not None, "ESMFold2 result has no pLDDT tensor." return float(result.plddt.float().mean().item()) def _ttt_select_result(self, result): if isinstance(result, list): assert len(result) > 0, "ESMFold2 fold returned an empty result list." return max(result, key=self._ttt_mean_plddt) return result def _ttt_eval_step( self, step: int, loss: float, seq: str | list[str] | None = None, input_ids: torch.Tensor | None = None, **kwargs, ) -> tuple[dict[str, Any], float | None]: del input_ids assert isinstance(seq, str), ( "ESMFold2 fold TTT is protein-only and sequence-string only." ) fold_kwargs = kwargs["fold_kwargs"] was_training = self.training self.eval() try: result = self._fold_protein_no_ttt(seq, **fold_kwargs) finally: self.train(was_training) selected = self._ttt_select_result(result) plddt = self._ttt_mean_plddt(selected) return { "step": step, "loss": loss, "plddt": plddt, "result": selected, }, plddt def fold_protein( self, sequence: str, *, chain_id: str = "A", msa: Any | None = None, msa_path: str | Path | None = None, msa_max_sequences: int | None = None, num_loops: int = 3, num_sampling_steps: int = 50, num_diffusion_samples: int = 1, seed: int | None = None, complex_id: str = "pred", ttt: bool = False, ttt_config: TTTConfig | dict[str, Any] | None = None, ): if ttt: return self.fold_protein_ttt( sequence=sequence, chain_id=chain_id, msa=msa, msa_path=msa_path, msa_max_sequences=msa_max_sequences, num_loops=num_loops, num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, seed=seed, complex_id=complex_id, ttt_config=ttt_config, ) return self._fold_protein_no_ttt( sequence=sequence, chain_id=chain_id, msa=msa, msa_path=msa_path, msa_max_sequences=msa_max_sequences, num_loops=num_loops, num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, seed=seed, complex_id=complex_id, ) def fold_protein_ttt( self, sequence: str, *, chain_id: str = "A", msa: Any | None = None, msa_path: str | Path | None = None, msa_max_sequences: int | None = None, num_loops: int = 3, num_sampling_steps: int = 50, num_diffusion_samples: int = 1, seed: int | None = None, complex_id: str = "pred", ttt_config: TTTConfig | dict[str, Any] | None = None, ): assert self._esmc is not None, "ESMFold2 TTT requires load_esmc=True." if self._esmc_fp8: raise RuntimeError("ESMFold2 TTT is not supported with fp8 ESM++.") fold_kwargs = { "chain_id": chain_id, "msa": msa, "msa_path": msa_path, "msa_max_sequences": msa_max_sequences, "num_loops": num_loops, "num_sampling_steps": num_sampling_steps, "num_diffusion_samples": num_diffusion_samples, "seed": seed, "complex_id": complex_id, } baseline = self._ttt_select_result( self._fold_protein_no_ttt(sequence, **fold_kwargs) ) baseline_plddt = self._ttt_mean_plddt(baseline) best_result = baseline best_plddt = baseline_plddt best_step = 0 step_plddts = [baseline_plddt] cfg = self.ttt_config.merged(ttt_config).merged( {"eval_each_step": True, "automatic_best_state_reset": False} ) try: metrics = self.ttt( seq=sequence, ttt_config=cfg, fold_kwargs=fold_kwargs, ) for step_metric in metrics["step_metrics"]: step_plddt = step_metric["plddt"] step_plddts.append(step_plddt) if step_plddt > best_plddt: best_plddt = step_plddt best_step = step_metric["step"] best_result = step_metric["result"] best_result.ttt_metrics = { "losses": metrics["losses"], "step_plddts": step_plddts, "baseline_plddt": baseline_plddt, "best_plddt": best_plddt, "best_step": best_step, } return best_result finally: if "_ttt_initialized" in self.__dict__ and self._ttt_initialized: self.ttt_reset() @staticmethod def result_to_cif(result) -> str: assert not isinstance(result, list), "Pass one MolecularComplexResult at a time." return result.complex.to_mmcif() @staticmethod def result_to_pdb(result) -> str: assert not isinstance(result, list), "Pass one MolecularComplexResult at a time." return result.complex.to_protein_complex().to_pdb_string() def save_as_cif(self, result, output_path: str | Path) -> None: Path(output_path).write_text(self.result_to_cif(result)) def save_as_pdb(self, result, output_path: str | Path) -> None: Path(output_path).write_text(self.result_to_pdb(result)) def infer_protein_as_cif(self, seq: str, **forward_kwargs) -> str: return self.result_to_cif(self.fold_protein(seq, **forward_kwargs)) def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str: return self.result_to_pdb(self.fold_protein(seq, **forward_kwargs)) class MSAEncoderBlock(nn.Module): """One MSA encoder block: OPM into pair, MSA pair-weighted averaging, triangle update.""" def __init__( self, d_msa: int, d_pair: int, d_hidden: int, n_heads_msa: int, msa_head_width: int, is_final_block: bool = False, ) -> None: super().__init__() self.is_final_block = is_final_block self.outer_product_mean = OuterProductMean(d_msa, d_hidden, d_pair) if not is_final_block: self.msa_pair_weighted_averaging = MSAPairWeightedAveraging( d_msa, d_pair, n_heads_msa, msa_head_width ) self.msa_transition = PairTransition(d_msa, expansion_ratio=4) self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True) self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False) self.pair_transition = PairTransition(d_pair, expansion_ratio=4) def set_chunk_size(self, chunk_size: int | None) -> None: self.outer_product_mean.set_chunk_size(chunk_size) self.tri_mul_out.set_chunk_size(chunk_size) self.tri_mul_in.set_chunk_size(chunk_size) if not self.is_final_block: self.msa_transition.set_chunk_size(chunk_size) self.pair_transition.set_chunk_size(chunk_size) def forward( self, m: Tensor, pair: Tensor, msa_attention_mask: Tensor, pair_attention_mask: Tensor, ) -> tuple[Tensor, Tensor]: pair = pair + self.outer_product_mean(m, msa_attention_mask) if not self.is_final_block: m = m + self.msa_pair_weighted_averaging(m, pair, pair_attention_mask) m = m + self.msa_transition(m) pair = pair + self.tri_mul_out(pair, mask=pair_attention_mask) pair = pair + self.tri_mul_in(pair, mask=pair_attention_mask) pair = pair + self.pair_transition(pair) return m, pair class MSAEncoder(nn.Module): """Stack of [`MSAEncoderBlock`] layers that conditions the pair on an MSA.""" def __init__( self, d_msa: int, d_pair: int, d_inputs: int, d_hidden: int = 32, n_layers: int = 4, n_heads_msa: int = 8, msa_head_width: int = 16, ) -> None: super().__init__() self.embed = nn.Linear(35, d_msa, bias=False) self.project_inputs = nn.Linear(d_inputs, d_msa, bias=False) self.blocks = nn.ModuleList( [ MSAEncoderBlock( d_msa=d_msa, d_pair=d_pair, d_hidden=d_hidden, n_heads_msa=n_heads_msa, msa_head_width=msa_head_width, is_final_block=(i == n_layers - 1), ) for i in range(n_layers) ] ) def set_chunk_size(self, chunk_size: int | None) -> None: for block in self.blocks: cast(MSAEncoderBlock, block).set_chunk_size(chunk_size) def forward( self, x_pair: Tensor, x_inputs: Tensor, msa_oh: Tensor, has_deletion: Tensor, deletion_value: Tensor, msa_attention_mask: Tensor, ) -> Tensor: # All inputs are pre-transposed to [B, L, M, ...] before calling. m_feat = torch.cat( [msa_oh, has_deletion.unsqueeze(-1), deletion_value.unsqueeze(-1)], dim=-1 ) m = self.embed(m_feat) + self.project_inputs(x_inputs).unsqueeze(2) tok_mask = msa_attention_mask[:, :, 0].bool() pair_attention_mask = tok_mask.unsqueeze(2) & tok_mask.unsqueeze(1) for block in self.blocks: m, x_pair = block(m, x_pair, msa_attention_mask, pair_attention_mask) return x_pair