"""FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.

Usage:
    from transformers import AutoModel
    model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda()

    # Basic folding
    result = model.fold_protein("MKTLLILAVVA...")
    print(result["plddt"], result["pdb_string"][:100])

    # Folding with TTT (test-time training improves structure prediction)
    result = model.fold_protein("MKTLLILAVVA...", ttt=True)

Dependencies: torch, transformers, einops, peft (for LoRA TTT only)
No dependency on: esm (fair-esm), proteinttt, openfold
"""
import copy
from dataclasses import dataclass, field
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from einops import rearrange
from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
    EsmContactPredictionHead,
    EsmEmbeddings,
    EsmIntermediate,
    EsmLMHead,
    EsmOutput,
    EsmSelfOutput,
    RotaryEmbedding,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding


# =============================================================================
# Flash Attention Detection (from FastPLMs/esm2/modeling_fastesm.py)
# =============================================================================

# Flex attention shipped in newer torch releases; fall back to None markers so the
# rest of the module can feature-detect with plain `is None` checks.
try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
    create_block_mask = None
    flex_attention = None
    BlockMask = None

# Cache for the torch.compile-wrapped flex_attention (compiled at most once).
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return a callable flex_attention implementation, or None if unavailable.

    Compiles ``flex_attention`` with ``torch.compile`` on first use and caches the
    result; the ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG`` escape hatch returns the
    eager function for debugging.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    flex_mod = torch.nn.attention.flex_attention
    if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
        return flex_attention
    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(flex_attention)
    return _compiled_flex_attention


def _infer_kernels_flash_variant(kernel) -> str | None:
    """Classify a loaded kernels-hub flash attention module by its exposed API.

    Returns "flash_attn2" (fwd/varlen_fwd entry points), "flash_attn3"
    (flash_attn_func/flash_attn_varlen_func entry points), or None if neither
    API surface is present.
    """
    if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
        return "flash_attn2"
    if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
        return "flash_attn3"
    return None


def _try_get_kernels_flash():
    """Best-effort load of a flash attention kernel from the `kernels` hub.

    Tries flash-attn3 first, then flash-attn2. Returns ``(kernel, variant)`` or
    ``(None, None)`` when the `kernels` package is missing or no kernel loads
    with a recognized API.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    flash_kernel = None
    flash_kernel_variant = None
    try:
        flash_kernel = get_kernel("kernels-community/flash-attn3")
        flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
        assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
    except Exception:
        try:
            flash_kernel = get_kernel("kernels-community/flash-attn2")
            flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
            assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
        except Exception:
            flash_kernel = None
            flash_kernel_variant = None
    return flash_kernel, flash_kernel_variant
# Lazily-populated flash kernel state (see _ensure_flash_kernels_loaded).
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Load the kernels-hub flash attention kernel exactly once (idempotent)."""
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if _FLASH_KERNELS_LOADED:
        return
    _FLASH_KERNELS_LOADED = True
    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()


def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Dense (non-varlen) flash attention forward via the loaded kernel.

    Dispatches on FLASH_KERNEL_VARIANT; the flash_attn3 path retries with
    positional arguments because some kernel builds reject keyword calls.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
        except TypeError:
            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length (unpadded) flash attention forward via the loaded kernel.

    Inputs are the unpadded token-major Q/K/V plus cumulative sequence-length
    tensors, as produced by ``_unpad_input``.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states, k=key_states, v=value_states,
                cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            # Older kernel builds only accept positional arguments.
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states, key_states, value_states,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_in_batch_q, max_seqlen_in_batch_k,
                0.0, None, causal,
            )
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


# Unpad / Pad helpers for varlen flash attention
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable gather along the first axis: output = input[indices]."""

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        # Scatter gradients back into a zero tensor of the original first-axis size.
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter along the first axis: zeros[indices] = values."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None


index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Re-pad unpadded token-major hidden states back to (batch, seqlen, ...)."""
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
-> b s ...", b=batch) - - -def _unpad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask_2d: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: - batch_size, seq_len, num_heads, head_dim = query_layer.shape - seqlens = attention_mask_2d.sum(dim=1).int() - cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)) - max_seqlen = int(seqlens.max().item()) - indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten() - query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen) - - -def kernels_flash_attention_func( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask_2d: torch.Tensor | None = None, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." 
- if not causal and attention_mask_2d is not None: - batch_size, q_len = query_states.shape[:2] - ( - query_states, key_states, value_states, - indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), - ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) - attn_output_unpad = _kernels_flash_varlen_forward( - query_states=query_states, key_states=key_states, value_states=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, - ) - return pad_input(attn_output_unpad, indices_q, batch_size, q_len) - else: - return _kernels_flash_forward( - query_states=query_states, key_states=key_states, value_states=value_states, causal=causal, - ) - - -# ============================================================================= -# Attention Backend Enum & Resolution -# ============================================================================= - -class AttentionBackend(Enum): - AUTO = "auto" - KERNELS_FLASH = "kernels_flash" - FLEX = "flex" - SDPA = "sdpa" - - -VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) - -_BACKEND_CONFIRMED = False - - -def resolve_attention_backend(requested_backend: str) -> AttentionBackend: - global _BACKEND_CONFIRMED - assert requested_backend in VALID_ATTENTION_BACKENDS, ( - f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." - ) - if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value): - _ensure_flash_kernels_loaded() - if requested_backend == AttentionBackend.AUTO.value: - if FLASH_KERNEL is not None: - resolved = AttentionBackend.KERNELS_FLASH - elif flex_attention is not None: - resolved = AttentionBackend.FLEX - else: - resolved = AttentionBackend.SDPA - elif requested_backend == AttentionBackend.KERNELS_FLASH.value: - assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." 
- resolved = AttentionBackend.KERNELS_FLASH - elif requested_backend == AttentionBackend.FLEX.value: - assert flex_attention is not None, "Flex Attention is not available in this environment." - resolved = AttentionBackend.FLEX - elif requested_backend == AttentionBackend.SDPA.value: - resolved = AttentionBackend.SDPA - else: - raise AssertionError(f"Unsupported attention backend: {requested_backend}") - if not _BACKEND_CONFIRMED: - print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") - _BACKEND_CONFIRMED = True - return resolved - - -def get_attention_mask( - effective_backend: AttentionBackend, - batch_size: int, - seq_len: int, - device: torch.device, - attention_mask: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]: - if attention_mask is None: - return None, None, None - - attention_mask_2d = attention_mask.bool() - - if effective_backend == AttentionBackend.KERNELS_FLASH: - return attention_mask_2d, None, None - - if effective_backend == AttentionBackend.FLEX: - assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." 
# =============================================================================
# Output Dataclass
# =============================================================================

@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Encoder output: final hidden states plus optional per-layer extras."""
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


# =============================================================================
# FastESM2 Attention Layers (multi-backend: SDPA, Flash, Flex)
# =============================================================================

class EsmSelfAttention(nn.Module):
    """ESM2 self-attention with pluggable backends (kernels flash / flex / SDPA).

    Queries are pre-scaled by 1/sqrt(head_dim), so every backend is invoked with
    scale=1.0 (or no extra scaling for the flash kernel).
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.config = config
        self.attn_backend = resolve_attention_backend(config.attn_backend)
        self.position_embedding_type = position_embedding_type or config.position_embedding_type
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Project to Q/K/V, apply rotary embeddings if configured, run attention.

        Returns (attn_output of shape (B, L, hidden), attn_weights or None).
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        # Suffix convention: B=batch, H=heads, L=seq length, D=head dim.
        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
        query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pre-scale queries; all backends below then run with scale=1.0.
        query_BHLD = query_BHLD * self.scale

        if self.position_embedding_type == "rotary":
            query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)

        attn_output, attn_weights = self._attn(
            query_BHLD, key_BHLD, value_BHLD,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return attn_output, attn_weights

    def _attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Dispatch to the resolved backend; manual path when weights are requested."""
        if output_attentions:
            # Only the manual implementation materializes attention weights.
            return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)

        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
        elif self.attn_backend == AttentionBackend.FLEX:
            return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
        elif self.attn_backend == AttentionBackend.SDPA:
            return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Explicit softmax attention; returns both output and attention weights."""
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
        if attention_mask_4d is not None:
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout_prob > 0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return attn_output, attn_weights

    def _kernels_flash_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, None]:
        """Kernels-hub flash attention; expects (B, L, H, D) contiguous inputs."""
        query_BLHD = query_BHLD.transpose(1, 2).contiguous()
        key_BLHD = key_BHLD.transpose(1, 2).contiguous()
        value_BLHD = value_BHLD.transpose(1, 2).contiguous()
        attn_output = kernels_flash_attention_func(
            query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
            attention_mask_2d=attention_mask_2d, causal=False,
        )
        return rearrange(attn_output, "b s h d -> b s (h d)"), None

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: "BlockMask | None" = None,
    ) -> tuple[torch.Tensor, None]:
        """Flex attention with a precomputed BlockMask (scale=1.0: queries pre-scaled)."""
        assert flex_attention is not None, "Flex attention is not available in this environment."
        fn = _get_flex_attention_fn()
        context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, None]:
        """torch SDPA fallback (scale=1.0: queries pre-scaled in forward)."""
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=attention_mask_4d,
            dropout_p=self.dropout_prob if self.training else 0.0,
            scale=1.0,
        )
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None


class EsmAttention(nn.Module):
    """Pre-LN attention sub-block: LayerNorm -> self-attention -> residual output."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        hidden_states_ln = self.LayerNorm(hidden_states)
        attn_output, attn_weights = self.self(
            hidden_states_ln,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        # Residual connection uses the un-normalized input (pre-LN convention).
        attention_output = self.output(attn_output, hidden_states)
        return attention_output, attn_weights


class EsmLayer(nn.Module):
    """One transformer layer: attention sub-block followed by a pre-LN MLP."""

    def __init__(self, config):
        super().__init__()
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        attention_output, attn_weights = self.attention(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        layer_output = self._feed_forward(attention_output)
        return layer_output, attn_weights

    def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Pre-LN MLP with residual connection to the attention output."""
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        return self.output(intermediate_output, attention_output)
class FastEsmEncoder(nn.Module):
    """Stack of EsmLayers sharing one resolved attention backend and mask set."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attention_backend = resolve_attention_backend(config.attn_backend)
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> FastEsmEncoderOutput:
        """Run all layers; optionally collect per-layer hidden states / attentions."""
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Masks are built once for the whole stack, in the representation the
        # resolved backend expects.
        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
            device=hidden_states.device,
            attention_mask=attention_mask,
        )

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states, attn_weights = layer_module(
                hidden_states,
                attention_mask_2d=attention_mask_2d,
                attention_mask_4d=attention_mask_4d,
                flex_block_mask=flex_block_mask,
                output_attentions=output_attentions,
            )

            if all_attentions is not None:
                all_attentions = all_attentions + (attn_weights,)

        # NOTE(review): truthiness check on an nn.Module is always True here;
        # kept as-is since emb_layer_norm_after is always constructed above.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return FastEsmEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


# =============================================================================
# FastESM Backbone (replaces EsmModel inside ESMFold)
# =============================================================================

class FastEsmBackbone(nn.Module):
    """FastESM2 backbone with multi-backend attention. Drop-in replacement for
    transformers.EsmModel inside EsmForProteinFolding.

    State dict keys match HuggingFace EsmModel exactly, so pretrained weights
    load without any key remapping.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = FastEsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastEsmEncoderOutput:
        """Embed tokens and encode; extra kwargs (e.g. EsmModel-only flags) are ignored."""
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        return FastEsmEncoderOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# =============================================================================
# TTT (Test-Time Training) Configuration and Utilities
# =============================================================================

# The 20 standard amino acids; used to restrict TTT masking/replacement tokens.
_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")


class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Replaces an existing nn.Linear with base(x) + lora_up(lora_down(x)) * scale.
    Initialization follows cloneofsimo/lora: down=Normal(0, 1/r), up=zeros,
    so the layer is an exact identity wrapper of the base linear at init.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale


def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Replace nn.Linear layers inside modules matching target_class_name with LoRA.

    Matches lora_diffusion's inject_trainable_lora behavior: finds all modules whose
    class name matches target_class_name, then replaces their nn.Linear children with
    LoraInjectedLinear. Returns the list of trainable LoRA parameters.
    """
    lora_params: List[nn.Parameter] = []
    for _parent_name, parent_module in model.named_modules():
        if parent_module.__class__.__name__ != target_class_name:
            continue
        # Snapshot children before mutating the parent module in-place.
        for child_name, child_module in list(parent_module.named_children()):
            if not isinstance(child_module, nn.Linear):
                continue
            lora_linear = LoraInjectedLinear(child_module, r=r, scale=scale)
            lora_linear = lora_linear.to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, lora_linear)
            lora_params.extend(lora_linear.lora_down.parameters())
            lora_params.extend(lora_linear.lora_up.parameters())
    return lora_params


@dataclass
class TTTConfig:
    """Hyperparameters for test-time training (MLM fine-tuning at inference).

    lr/ags/steps/batch_size control the optimization schedule; mask_ratio and the
    bert_* probabilities control BERT-style masking; lora_* select the adapter
    configuration (rank 0 disables LoRA).
    """
    lr: float = 4e-4
    ags: int = 4                    # gradient accumulation steps
    steps: int = 10
    batch_size: int = 4
    mask_ratio: float = 0.15
    crop_size: int = 1024
    bert_leave_prob: float = 0.1
    bert_replace_prob: float = 0.1
    optimizer: str = "sgd"
    momentum: float = 0.0
    weight_decay: float = 0.0
    seed: Optional[int] = 0
    initial_state_reset: bool = True
    freeze_embeddings: bool = True
    lora_rank: int = 8
    lora_alpha: float = 32.0
    lora_target_class: str = "EsmSelfAttention"

    def verify(self) -> None:
        """Assert all hyperparameters are in valid ranges; raises AssertionError otherwise."""
        assert self.lr > 0.0, "TTT learning rate must be positive."
        assert self.ags > 0, "TTT ags must be positive."
        assert self.steps >= 0, "TTT steps must be non-negative."
        assert self.batch_size > 0, "TTT batch_size must be positive."
        assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
        assert self.crop_size > 0, "TTT crop_size must be positive."
        assert 0.0 <= self.bert_leave_prob <= 1.0
        assert 0.0 <= self.bert_replace_prob <= 1.0
        assert self.bert_leave_prob + self.bert_replace_prob <= 1.0
        assert self.optimizer in {"sgd", "adamw"}
        assert self.lora_rank >= 0
        assert self.lora_alpha > 0.0
- assert 0.0 <= self.bert_leave_prob <= 1.0 - assert 0.0 <= self.bert_replace_prob <= 1.0 - assert self.bert_leave_prob + self.bert_replace_prob <= 1.0 - assert self.optimizer in {"sgd", "adamw"} - assert self.lora_rank >= 0 - assert self.lora_alpha > 0.0 - - -def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - was_training = self.training - original_device = next(self.parameters()).device - original_requires_grad = { - name: parameter.requires_grad - for name, parameter in self.named_parameters() - } - try: - return func(self, *args, **kwargs) - finally: - self.train(was_training) - self.to(original_device) - for name, parameter in self.named_parameters(): - if name in original_requires_grad: - parameter.requires_grad = original_requires_grad[name] - else: - parameter.requires_grad = False - return wrapper - - -# ============================================================================= -# FastEsmFoldConfig -# ============================================================================= - -class FastEsmFoldConfig(EsmConfig): - model_type = "fast_esmfold" - - def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs): - super().__init__(**kwargs) - self.attn_backend = attn_backend - self.ttt_config = ttt_config or { - "lr": 4e-4, - "steps": 10, - "lora_rank": 8, - "lora_alpha": 32.0, - } - - -# ============================================================================= -# FastEsmForProteinFolding -# ============================================================================= - -class FastEsmForProteinFolding(EsmForProteinFolding): - """ESMFold with FastESM2 attention backends + built-in Test-Time Training. - - Inherits all folding logic (trunk, structure module, output_to_pdb, infer) - from transformers.EsmForProteinFolding. 
Replaces the ESM2 backbone with - FastESM2 for optimized attention and adds TTT for improved structure prediction. - - Key API: - result = model.fold_protein("MKTL...", ttt=True) - # result = {"plddt": float, "ptm": float, "pdb_string": str} - """ - config_class = FastEsmFoldConfig - - def __init__(self, config: FastEsmFoldConfig): - super().__init__(config) - - # Replace standard ESM2 backbone with FastESM2 (multi-backend attention) - # unless use_standard_backbone is set (for TTT debugging/compatibility) - if not config.ttt_config.get("use_standard_backbone", False): - self.esm = FastEsmBackbone(config) - self.esm.requires_grad_(False) - if config.esmfold_config.fp16_esm: - self.esm.half() - - # MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear) - self.mlm_head = EsmLMHead(config) - - # TTT state (lazy initialization) - ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"} - self._ttt_cfg = TTTConfig(**ttt_kwargs) - self._ttt_cfg.verify() - self._ttt_initialized = False - self._ttt_initial_state = None - self._ttt_generator = torch.Generator() - if self._ttt_cfg.seed is not None: - self._ttt_generator.manual_seed(self._ttt_cfg.seed) - self._non_special_tokens_cache = None - self._ttt_tokenizer = None - - def _get_ttt_tokenizer(self) -> EsmTokenizer: - if self._ttt_tokenizer is None: - self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") - return self._ttt_tokenizer - - def _ensure_ttt_ready(self) -> None: - """Lazy TTT initialization. Injects LoRA adapters and saves initial state. 
- Must be called after weights are loaded (not in __init__).""" - if self._ttt_initialized: - return - self._ttt_initialized = True - - tokenizer = self._get_ttt_tokenizer() - vocab = tokenizer.get_vocab() - self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] - - if self._ttt_cfg.lora_rank > 0: - self.mlm_head.eval() - for p in self.mlm_head.parameters(): - p.requires_grad = False - # Seed global state before LoRA init for reproducible weight initialization - if self._ttt_cfg.seed is not None: - torch.manual_seed(self._ttt_cfg.seed) - self._inject_lora() - else: - # Legacy path: jointly-trained random linear projection head - H = self.config.hidden_size - V = self.config.vocab_size - device = next(self.esm.parameters()).device - self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device) - - if self._ttt_cfg.initial_state_reset: - self._ttt_initial_state = self._ttt_get_state() - - @property - def _uses_lora(self) -> bool: - return self._ttt_cfg.lora_rank > 0 - - def _inject_lora(self) -> None: - """Inject LoRA adapters into ESM2 attention layers (matching lora_diffusion behavior).""" - self._lora_params = inject_trainable_lora( - self.esm, - target_class_name=self._ttt_cfg.lora_target_class, - r=self._ttt_cfg.lora_rank, - scale=self._ttt_cfg.lora_alpha, - ) - assert len(self._lora_params) > 0, ( - f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' " - f"matches attention modules in the backbone." 
- ) - - # ---- TTT State Management ---- - - def _get_lora_modules(self) -> List[LoraInjectedLinear]: - """Find all LoraInjectedLinear modules in the backbone.""" - return [m for m in self.esm.modules() if isinstance(m, LoraInjectedLinear)] - - def _ttt_get_state(self) -> Dict[str, Any]: - if self._uses_lora: - lora_state = [] - for m in self._get_lora_modules(): - lora_state.append({ - "down": m.lora_down.weight.data.clone(), - "up": m.lora_up.weight.data.clone(), - }) - return {"_lora_state": lora_state} - return { - "esm": copy.deepcopy(self.esm), - "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj), - } - - def _ttt_set_state(self, state: Dict[str, Any]) -> None: - if "_lora_state" in state: - modules = self._get_lora_modules() - assert len(modules) == len(state["_lora_state"]) - for m, saved in zip(modules, state["_lora_state"]): - m.lora_down.weight.data.copy_(saved["down"]) - m.lora_up.weight.data.copy_(saved["up"]) - return - if "esm" in state: - self.esm = copy.deepcopy(state["esm"]) - if "_ttt_lm_proj" in state: - self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"]) - - def ttt_reset(self) -> None: - """Reset model to pre-TTT state (restore initial LoRA or backbone weights).""" - assert self._ttt_initial_state is not None, "TTT reset requires initial_state_reset=True." 
- self._ttt_set_state(self._ttt_initial_state) - - # ---- TTT Core ---- - - def _ttt_tokenize(self, seq: str) -> torch.Tensor: - tokenizer = self._get_ttt_tokenizer() - out = tokenizer( - seq, - return_tensors="pt", - add_special_tokens=self._uses_lora, - padding=False, - truncation=False, - ) - return out["input_ids"] - - def _ttt_mask_token(self) -> int: - return self._get_ttt_tokenizer().mask_token_id - - def _ttt_get_non_special_tokens(self) -> List[int]: - if self._non_special_tokens_cache is not None: - return self._non_special_tokens_cache - tokenizer = self._get_ttt_tokenizer() - vocab = tokenizer.get_vocab() - self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] - return self._non_special_tokens_cache - - def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor: - """Run ESM2 backbone + LM head to get MLM logits.""" - # Temporarily unfreeze backbone for gradient flow during TTT - output = self.esm(input_ids=batch) - hidden = output.last_hidden_state - if self._uses_lora: - return self.mlm_head(hidden) - return self._ttt_lm_proj(hidden) - - def _ttt_sample_batch( - self, - x: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - _, seq_len = x.shape - batch_size = self._ttt_cfg.batch_size - crop_size = min(self._ttt_cfg.crop_size, seq_len) - - x_expanded = x.expand(batch_size, -1) - if seq_len == crop_size: - start_indices = torch.zeros(batch_size, dtype=torch.long) - else: - start_indices = torch.randint( - 0, seq_len - crop_size + 1, (batch_size,), - generator=self._ttt_generator, - ).to(torch.long) - - batch_cropped = torch.stack([ - x_expanded[index, start : start + crop_size] - for index, start in enumerate(start_indices) - ]) - - non_special_tokens = set(self._ttt_get_non_special_tokens()) - mask = torch.zeros((batch_size, crop_size), dtype=torch.bool) - mask_token_id = self._ttt_mask_token() - - for row_index in range(batch_size): - non_special_positions = [ - col for col in 
range(crop_size) - if batch_cropped[row_index, col].item() in non_special_tokens - ] - assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token." - num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio))) - sampled_indices = torch.randperm( - len(non_special_positions), generator=self._ttt_generator, - )[:num_to_mask] - positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices] - mask[row_index, positions_to_mask] = True - - batch_masked = batch_cropped.clone() - for row_index in range(batch_size): - masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0] - for masked_position in masked_positions: - probability = float(torch.rand(1, generator=self._ttt_generator).item()) - if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob: - batch_masked[row_index, masked_position] = mask_token_id - continue - if probability < 1.0 - self._ttt_cfg.bert_leave_prob: - replacement_candidates = self._ttt_get_non_special_tokens() - replacement_index = int(torch.randint( - 0, len(replacement_candidates), (1,), generator=self._ttt_generator, - ).item()) - batch_masked[row_index, masked_position] = replacement_candidates[replacement_index] - - return batch_masked, batch_cropped, mask, start_indices - - def _ttt_cross_entropy_loss( - self, - logits: torch.Tensor, - targets: torch.Tensor, - mask: torch.Tensor, - ) -> torch.Tensor: - assert logits.ndim == 3, "Logits must be [batch, seq, vocab]." - _, _, vocab_size = logits.shape - logits_flat = logits.reshape(-1, vocab_size) - targets_flat = targets.reshape(-1) - mask_flat = mask.reshape(-1) - assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token." 
- loss = F.cross_entropy( - logits_flat[mask_flat], - targets_flat[mask_flat], - reduction="none", - ) - masked_tokens_per_seq = mask.sum(dim=1).tolist() - per_sequence_losses = torch.split(loss, masked_tokens_per_seq) - return torch.stack([sl.mean() for sl in per_sequence_losses]).mean() - - def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer: - if self._ttt_cfg.optimizer == "sgd": - return torch.optim.SGD( - parameters, - lr=self._ttt_cfg.lr, - momentum=self._ttt_cfg.momentum, - weight_decay=self._ttt_cfg.weight_decay, - ) - return torch.optim.AdamW( - parameters, - lr=self._ttt_cfg.lr, - weight_decay=self._ttt_cfg.weight_decay, - ) - - def _lora_ttt(self, seq: str) -> Dict[str, List[float]]: - """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen.""" - x = self._ttt_tokenize(seq) - device = next(self.parameters()).device - non_blocking = device.type == "cuda" - losses = [] - - if self._ttt_cfg.steps == 0: - return {"losses": losses} - - for parameter in self.parameters(): - parameter.requires_grad = False - for p in self._lora_params: - p.requires_grad = True - optimizer = self._ttt_get_optimizer(self._lora_params) - optimizer.zero_grad(set_to_none=True) - - self.eval() - for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): - batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x) - batch_masked = batch_masked.to(device, non_blocking=non_blocking) - targets = targets.to(device, non_blocking=non_blocking) - mask = mask.to(device, non_blocking=non_blocking) - - self.train() - logits = self._ttt_predict_logits(batch_masked) - loss = self._ttt_cross_entropy_loss(logits, targets, mask) - loss.backward() - losses.append(float(loss.detach().cpu().item())) - - if (step + 1) % self._ttt_cfg.ags == 0: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - self.eval() - return {"losses": losses} - - def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]: - """Legacy TTT: full fine-tuning of ESM2 backbone with 
random linear projection head.""" - x = self._ttt_tokenize(seq) - device = next(self.parameters()).device - non_blocking = device.type == "cuda" - losses = [] - - if self._ttt_cfg.steps == 0: - return {"losses": losses} - - # Full fine-tune: all backbone params trainable - for parameter in self.parameters(): - parameter.requires_grad = False - for parameter in self.esm.parameters(): - parameter.requires_grad = True - if self._ttt_cfg.freeze_embeddings: - for parameter in self.esm.embeddings.parameters(): - parameter.requires_grad = False - for parameter in self._ttt_lm_proj.parameters(): - parameter.requires_grad = True - - trainable_params = filter(lambda p: p.requires_grad, self.parameters()) - optimizer = self._ttt_get_optimizer(trainable_params) - optimizer.zero_grad(set_to_none=True) - - self.eval() - for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): - batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x) - batch_masked = batch_masked.to(device, non_blocking=non_blocking) - targets = targets.to(device, non_blocking=non_blocking) - mask = mask.to(device, non_blocking=non_blocking) - - self.train() - logits = self._ttt_predict_logits(batch_masked) - loss = self._ttt_cross_entropy_loss(logits, targets, mask) - loss.backward() - losses.append(float(loss.detach().cpu().item())) - - if (step + 1) % self._ttt_cfg.ags == 0: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - self.eval() - return {"losses": losses} - - @preserve_model_state - def ttt(self, seq: str) -> Dict[str, List[float]]: - """Run test-time training on a single sequence using masked language modeling. - - Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence - before structure prediction. Call fold_protein(seq, ttt=True) for the full pipeline. 
- - Args: - seq: Protein sequence (single-letter amino acid codes) - - Returns: - Dict with "losses" key containing per-step MLM loss values - """ - self._ensure_ttt_ready() - # TTT requires fp32 for stable gradient computation. ESMFold typically - # runs the backbone in fp16, but small LoRA updates vanish in half precision. - esm_dtype = next(self.esm.parameters()).dtype - if esm_dtype != torch.float32: - self.esm.float() - self.mlm_head.float() - if self._uses_lora: - result = self._lora_ttt(seq) - else: - result = self._legacy_ttt(seq) - # Restore original dtype (backbone back to fp16 for inference) - if esm_dtype != torch.float32: - self.esm.to(esm_dtype) - self.mlm_head.to(esm_dtype) - return result - - # ---- High-Level API ---- - - def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]: - """Fold a sequence once and return pLDDT, ptm, and optionally PDB string.""" - with torch.no_grad(): - output = self.infer(sequence) - plddt = output["plddt"] - if plddt.dim() >= 2: - mean_plddt = float(plddt.mean(dim=-1).mean().item()) - else: - mean_plddt = float(plddt.mean().item()) - result = { - "plddt": mean_plddt, - "ptm": float(output["ptm"].item()) if "ptm" in output else None, - } - if return_pdb_string: - pdb_strings = self.output_to_pdb(output) - result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings - return result - - def fold_protein( - self, - sequence: str, - return_pdb_string: bool = True, - ) -> Dict[str, Any]: - """Fold a protein sequence with test-time training. - - Runs TTT (masked language model adaptation via LoRA) for the configured - number of steps, folding after each optimizer step to track pLDDT. Returns - the structure with the highest pLDDT across all steps (including baseline). 
- - Args: - sequence: Protein sequence (single-letter amino acid codes) - return_pdb_string: If True, include PDB string in output - - Returns: - Dict with keys: - - plddt: float, best mean pLDDT across all TTT steps - - ptm: float, predicted TM-score from best step - - pdb_string: str (if return_pdb_string=True), PDB from best step - - step_plddts: list[float], pLDDT at each step [baseline, s1, ..., s10] - - best_step: int, which step produced best structure (0=baseline) - """ - self._ensure_ttt_ready() - - # Cast to fp32 for TTT stability - esm_dtype = next(self.esm.parameters()).dtype - if esm_dtype != torch.float32: - self.esm.float() - self.mlm_head.float() - - device = next(self.parameters()).device - non_blocking = device.type == "cuda" - - # Step 0: baseline fold (no TTT adaptation) - best = self._fold_single(sequence, return_pdb_string=return_pdb_string) - step_plddts = [best["plddt"]] - - if self._ttt_cfg.steps > 0: - # Tokenize for masked LM training - x = self._ttt_tokenize(sequence) - - # Freeze all, unfreeze LoRA - for p in self.parameters(): - p.requires_grad = False - if self._uses_lora: - for p in self._lora_params: - p.requires_grad = True - optimizer = self._ttt_get_optimizer(self._lora_params) - else: - for p in self.esm.parameters(): - p.requires_grad = True - if self._ttt_cfg.freeze_embeddings: - for p in self.esm.embeddings.parameters(): - p.requires_grad = False - for p in self._ttt_lm_proj.parameters(): - p.requires_grad = True - trainable = [p for p in self.parameters() if p.requires_grad] - optimizer = self._ttt_get_optimizer(trainable) - optimizer.zero_grad(set_to_none=True) - - self.eval() - for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): - batch_masked, targets, mask, _start = self._ttt_sample_batch(x) - batch_masked = batch_masked.to(device, non_blocking=non_blocking) - targets = targets.to(device, non_blocking=non_blocking) - mask = mask.to(device, non_blocking=non_blocking) - - self.train() - logits = 
self._ttt_predict_logits(batch_masked) - loss = self._ttt_cross_entropy_loss(logits, targets, mask) - loss.backward() - - if (step + 1) % self._ttt_cfg.ags == 0: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - # Fold after this optimizer step - self.eval() - current = self._fold_single(sequence, return_pdb_string=return_pdb_string) - step_plddts.append(current["plddt"]) - if current["plddt"] > best["plddt"]: - best = current - - self.eval() - - # Restore requires_grad - for p in self.parameters(): - p.requires_grad = False - - # Reset LoRA weights for next sequence - self.ttt_reset() - - # Restore dtype - if esm_dtype != torch.float32: - self.esm.to(esm_dtype) - self.mlm_head.to(esm_dtype) - - return { - "plddt": best["plddt"], - "ptm": best["ptm"], - "pdb_string": best.get("pdb_string"), - "step_plddts": step_plddts, - "best_step": step_plddts.index(max(step_plddts)), - } +import torch +import torch._inductor.config as inductor_config +import torch._dynamo as dynamo + +# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) +# Provides significant speedup with minimal precision loss +torch.set_float32_matmul_precision('high') + +# Enable TF32 for matrix multiplications and cuDNN operations +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +# Enable cuDNN autotuner - finds fastest algorithms for your hardware +# Best when input sizes are consistent; may slow down first iterations +torch.backends.cudnn.benchmark = True + +# Deterministic operations off for speed (set True if reproducibility needed) +torch.backends.cudnn.deterministic = False +inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" + +dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.recompile_limit = 16 + +"""Shared attention infrastructure for all FastPLMs models. 

Contains: AttentionBackend enum, backend resolution, mask creation,
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
"""
from enum import Enum
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

# flex_attention only exists in newer torch builds; fall back to None sentinels.
try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
    create_block_mask = None
    flex_attention = None
    BlockMask = None

# Process-wide cache for the torch.compile-wrapped flex_attention callable.
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    flex_mod = torch.nn.attention.flex_attention
    # Escape hatch: setting this attribute on the flex_attention module skips torch.compile.
    if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
        return flex_attention
    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(flex_attention)
    return _compiled_flex_attention


### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> str | None:
    """Classify a loaded hub kernel by its exposed API surface.

    Returns "flash_attn2" (fwd/varlen_fwd style), "flash_attn3"
    (flash_attn_func/flash_attn_varlen_func style), or None when the module
    matches neither known interface.
    """
    if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
        return "flash_attn2"
    if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
        return "flash_attn3"
    return None


def _try_get_kernels_flash():
    """Best-effort load of a flash-attention kernel from the HF `kernels` hub.

    Tries flash-attn3 first, then flash-attn2. Returns (kernel, variant) or
    (None, None) when the `kernels` package or both kernels are unavailable.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    flash_kernel = None
    flash_kernel_variant = None
    try:
        flash_kernel = get_kernel("kernels-community/flash-attn3")
        flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
        assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
    except Exception:
        try:
            flash_kernel = get_kernel("kernels-community/flash-attn2")
            flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
            assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
        except Exception:
            flash_kernel = None
            flash_kernel_variant = None
    return flash_kernel, flash_kernel_variant


# Lazily-populated module globals so kernel download/load happens at most once.
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Populate FLASH_KERNEL / FLASH_KERNEL_VARIANT exactly once (idempotent)."""
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if _FLASH_KERNELS_LOADED:
        return
    _FLASH_KERNELS_LOADED = True
    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()


def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Dense (non-varlen) flash attention; q/k/v are [batch, seq, heads, head_dim].

    NOTE(review): no softmax scale is forwarded here, so the kernel applies its
    own default (presumably 1/sqrt(head_dim)). Callers that pre-scale queries
    (as EsmSelfAttention.forward does) would then be double-scaled -- confirm
    against the kernel's signature whether softmax_scale must be pinned to 1.0.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
        except TypeError:
            # Positional fallback, presumably (q, k, v, dropout_p, softmax_scale, causal) -- TODO confirm.
            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length flash attention over unpadded [total_tokens, heads, head_dim] tensors."""
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states, k=key_states, v=value_states,
                cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            # Positional fallback for older kernel builds -- TODO confirm parameter order.
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states, key_states, value_states,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_in_batch_q, max_seqlen_in_batch_k,
                0.0, None, causal,
            )
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


### Unpad / Pad helpers for varlen flash attention
class IndexFirstAxis(torch.autograd.Function):
    """Gather rows along dim 0 with a backward that scatters grads back to full size."""

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # Flatten trailing dims so a single gather along dim 0 selects whole rows.
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        # Zero-init the full-size grad, then scatter selected rows' grads into place.
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


class IndexPutFirstAxis(torch.autograd.Function):
    """Inverse of IndexFirstAxis: scatter rows into a zero tensor with first_axis_dim rows."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None


index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Re-pad unpadded token rows to [batch, seqlen, ...]; pad positions come back as zeros."""
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)


def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    """Drop padding tokens from q/k/v ([batch, seq, heads, dim]) for varlen flash attention.

    Returns the unpadded tensors, the flat indices of kept tokens, cumulative
    sequence lengths for q/k, and the max sequence lengths for q/k (q and k
    share the same lengths here since this is self-attention).
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    seqlens = attention_mask_2d.sum(dim=1).int()
    # cu_seqlens is the exclusive prefix sum [0, len0, len0+len1, ...] flash kernels expect.
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)


def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: torch.Tensor | None = None,
    causal: bool = False,
) -> torch.Tensor:
    """Flash attention entry point: varlen path when a padding mask is given, dense otherwise."""
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if not causal and attention_mask_2d is not None:
        batch_size, q_len = query_states.shape[:2]
        (
            query_states, key_states, value_states,
            indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
        ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
        attn_output_unpad = _kernels_flash_varlen_forward(
            query_states=query_states, key_states=key_states, value_states=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
        )
        return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
    else:
        # NOTE(review): a causal request with a padding mask falls through here and
        # silently drops the mask -- confirm callers never combine the two.
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
        )


### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    """Supported attention implementations; AUTO picks the best available at resolve time."""
    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)


# Guard so the resolved-backend message is printed only once per process.
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a config string to a concrete AttentionBackend, validating availability.

    "auto" prefers kernels flash > flex > SDPA; explicit requests assert that
    the chosen backend is actually usable in this environment.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    # Only probe (and possibly download) the hub kernels when they could be used.
    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
        _ensure_flash_kernels_loaded()
    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")
    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved


@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
    """Build padding masks once for all encoder layers.

    Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
    """
    if attention_mask is None:
        return None, None, None

    attention_mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # Flash varlen path consumes the 2D mask directly via the unpad/pad helpers.
        return attention_mask_2d, None, None

    if effective_backend == AttentionBackend.FLEX:
        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
        valid_lens = attention_mask_2d.sum(dim=-1)

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            # NOTE(review): treats the first valid_lens[b] positions as valid, i.e.
            # assumes right-padded sequences -- confirm tokenizer padding side.
            return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])

        flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
        return attention_mask_2d, None, flex_block_mask

    # SDPA / manual -- only mask the key dimension so padding query positions attend to
    # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
    attention_mask_4d = attention_mask_2d[:, None, None, :]
    return attention_mask_2d, attention_mask_4d, None

# NOTE(review): everything below is the original FastESMFold module concatenated
# after the attention infrastructure; this triple-quoted string is therefore a
# no-op expression here, not the module docstring.
"""FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.

Usage:
    from transformers import AutoModel
    model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda()

    # Basic folding
    result = model.fold_protein("MKTLLILAVVA...")
    print(result["plddt"], result["pdb_string"][:100])

    # Folding with TTT (test-time training improves structure prediction)
    result = model.fold_protein("MKTLLILAVVA...", ttt=True)

Dependencies: torch, transformers, einops, peft (for LoRA TTT only)
No dependency on: esm (fair-esm), proteinttt, openfold
"""
import copy
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from einops import rearrange
from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
    EsmContactPredictionHead,
    EsmEmbeddings,
    EsmIntermediate,
    EsmLMHead,
    EsmOutput,
    EsmSelfOutput,
    RotaryEmbedding,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding




# =============================================================================
# Output Dataclass
# =============================================================================

@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Encoder output container: final hidden states plus optional per-layer states/attentions."""
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


# =============================================================================
# FastESM2 Attention Layers (multi-backend: SDPA,
Flash, Flex) +# ============================================================================= + +class EsmSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type: Optional[str] = None): + super().__init__() + assert config.hidden_size % config.num_attention_heads == 0, ( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.scale = self.attention_head_size**-0.5 + + self.dropout_prob = config.attention_probs_dropout_prob + self.config = config + self.attn_backend = resolve_attention_backend(config.attn_backend) + self.position_embedding_type = position_embedding_type or config.position_embedding_type + self.rotary_embeddings = None + if self.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask_2d: torch.Tensor | None = None, + attention_mask_4d: torch.Tensor | None = None, + flex_block_mask: "BlockMask | None" = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch_size, seq_length = hidden_states.shape[:-1] + hidden_shape = (batch_size, seq_length, -1, self.attention_head_size) + query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + query_BHLD = query_BHLD * self.scale + + if self.position_embedding_type == 
"rotary": + query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD) + + attn_output, attn_weights = self._attn( + query_BHLD, key_BHLD, value_BHLD, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + output_attentions=output_attentions, + ) + return attn_output, attn_weights + + def _attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + attention_mask_2d: torch.Tensor | None = None, + attention_mask_4d: torch.Tensor | None = None, + flex_block_mask: "BlockMask | None" = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if output_attentions: + return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) + + if self.attn_backend == AttentionBackend.KERNELS_FLASH: + return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d) + elif self.attn_backend == AttentionBackend.FLEX: + return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask) + elif self.attn_backend == AttentionBackend.SDPA: + return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) + else: + raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") + + def _manual_attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + attention_mask_4d: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) + if attention_mask_4d is not None: + attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf")) + attn_weights = F.softmax(attn_weights, dim=-1) + if self.dropout_prob > 0 and self.training: + attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training) + context_BHLD = torch.matmul(attn_weights, value_BHLD) + attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)") + return attn_output, 
attn_weights + + def _kernels_flash_attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + attention_mask_2d: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, None]: + query_BLHD = query_BHLD.transpose(1, 2).contiguous() + key_BLHD = key_BHLD.transpose(1, 2).contiguous() + value_BLHD = value_BHLD.transpose(1, 2).contiguous() + attn_output = kernels_flash_attention_func( + query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD, + attention_mask_2d=attention_mask_2d, causal=False, + ) + return rearrange(attn_output, "b s h d -> b s (h d)"), None + + def _flex_attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + flex_block_mask: "BlockMask | None" = None, + ) -> tuple[torch.Tensor, None]: + assert flex_attention is not None, "Flex attention is not available in this environment." + fn = _get_flex_attention_fn() + context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0) + return rearrange(context_BHLD, "b h s d -> b s (h d)"), None + + def _sdpa_attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + attention_mask_4d: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, None]: + context_BHLD = F.scaled_dot_product_attention( + query_BHLD, key_BHLD, value_BHLD, + attn_mask=attention_mask_4d, + dropout_p=self.dropout_prob if self.training else 0.0, + scale=1.0, + ) + return rearrange(context_BHLD, "b h s d -> b s (h d)"), None + + +class EsmAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = EsmSelfAttention(config) + self.output = EsmSelfOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask_2d: torch.Tensor | None = None, + attention_mask_4d: torch.Tensor | None = None, + flex_block_mask: "BlockMask | None" = None, + output_attentions: bool 
= False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + hidden_states_ln = self.LayerNorm(hidden_states) + attn_output, attn_weights = self.self( + hidden_states_ln, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + output_attentions=output_attentions, + ) + attention_output = self.output(attn_output, hidden_states) + return attention_output, attn_weights + + +class EsmLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = EsmAttention(config) + self.intermediate = EsmIntermediate(config) + self.output = EsmOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask_2d: torch.Tensor | None = None, + attention_mask_4d: torch.Tensor | None = None, + flex_block_mask: "BlockMask | None" = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + attention_output, attn_weights = self.attention( + hidden_states, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + output_attentions=output_attentions, + ) + layer_output = self._feed_forward(attention_output) + return layer_output, attn_weights + + def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor: + attention_output_ln = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(attention_output_ln) + return self.output(intermediate_output, attention_output) + + +class FastEsmEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.attention_backend = resolve_attention_backend(config.attn_backend) + self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) + self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + 
attention_mask: Optional[torch.Tensor] = None, + output_hidden_states: bool = False, + output_attentions: bool = False, + ) -> FastEsmEncoderOutput: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask( + effective_backend=self.attention_backend, + batch_size=hidden_states.shape[0], + seq_len=hidden_states.shape[1], + device=hidden_states.device, + attention_mask=attention_mask, + ) + + for layer_module in self.layer: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states, attn_weights = layer_module( + hidden_states, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + output_attentions=output_attentions, + ) + + if all_attentions is not None: + all_attentions = all_attentions + (attn_weights,) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return FastEsmEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +# ============================================================================= +# FastESM Backbone (replaces EsmModel inside ESMFold) +# ============================================================================= + +class FastEsmBackbone(nn.Module): + """FastESM2 backbone with multi-backend attention. Drop-in replacement for + transformers.EsmModel inside EsmForProteinFolding. + + State dict keys match HuggingFace EsmModel exactly, so pretrained weights + load without any key remapping. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.embeddings = EsmEmbeddings(config) + self.encoder = FastEsmEncoder(config) + self.contact_head = EsmContactPredictionHead( + in_features=config.num_hidden_layers * config.num_attention_heads, bias=True + ) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> FastEsmEncoderOutput: + output_attentions = output_attentions if output_attentions is not None else False + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + token_embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + token_embedding_output, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + return FastEsmEncoderOutput( + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# ============================================================================= +# TTT (Test-Time Training) Configuration and Utilities +# ============================================================================= + +_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY") + + +class LoraInjectedLinear(nn.Module): + """LoRA-augmented linear layer matching lora_diffusion's behavior. + + Replaces an existing nn.Linear with base(x) + lora_up(lora_down(x)) * scale. + Initialization follows cloneofsimo/lora: down=Normal(0, 1/r), up=zeros. 
+ """ + + def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0): + super().__init__() + self.linear = original_linear + in_features = original_linear.in_features + out_features = original_linear.out_features + assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})" + self.lora_down = nn.Linear(in_features, r, bias=False) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + nn.init.normal_(self.lora_down.weight, std=1.0 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale + + +def inject_trainable_lora( + model: nn.Module, + target_class_name: str, + r: int, + scale: float, +) -> List[nn.Parameter]: + """Replace nn.Linear layers inside modules matching target_class_name with LoRA. + + Matches lora_diffusion's inject_trainable_lora behavior: finds all modules whose + class name matches target_class_name, then replaces their nn.Linear children with + LoraInjectedLinear. Returns the list of trainable LoRA parameters. 
+ """ + lora_params: List[nn.Parameter] = [] + for _parent_name, parent_module in model.named_modules(): + if parent_module.__class__.__name__ != target_class_name: + continue + for child_name, child_module in list(parent_module.named_children()): + if not isinstance(child_module, nn.Linear): + continue + lora_linear = LoraInjectedLinear(child_module, r=r, scale=scale) + lora_linear = lora_linear.to( + device=child_module.weight.device, + dtype=child_module.weight.dtype, + ) + setattr(parent_module, child_name, lora_linear) + lora_params.extend(lora_linear.lora_down.parameters()) + lora_params.extend(lora_linear.lora_up.parameters()) + return lora_params + + +@dataclass +class TTTConfig: + lr: float = 4e-4 + ags: int = 4 + steps: int = 10 + batch_size: int = 4 + mask_ratio: float = 0.15 + crop_size: int = 1024 + bert_leave_prob: float = 0.1 + bert_replace_prob: float = 0.1 + optimizer: str = "sgd" + momentum: float = 0.0 + weight_decay: float = 0.0 + seed: Optional[int] = 0 + initial_state_reset: bool = True + freeze_embeddings: bool = True + lora_rank: int = 8 + lora_alpha: float = 32.0 + lora_target_class: str = "EsmSelfAttention" + + def verify(self) -> None: + assert self.lr > 0.0, "TTT learning rate must be positive." + assert self.ags > 0, "TTT ags must be positive." + assert self.steps >= 0, "TTT steps must be non-negative." + assert self.batch_size > 0, "TTT batch_size must be positive." + assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]." + assert self.crop_size > 0, "TTT crop_size must be positive." 
+ assert 0.0 <= self.bert_leave_prob <= 1.0 + assert 0.0 <= self.bert_replace_prob <= 1.0 + assert self.bert_leave_prob + self.bert_replace_prob <= 1.0 + assert self.optimizer in {"sgd", "adamw"} + assert self.lora_rank >= 0 + assert self.lora_alpha > 0.0 + + +def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + was_training = self.training + original_device = next(self.parameters()).device + original_requires_grad = { + name: parameter.requires_grad + for name, parameter in self.named_parameters() + } + try: + return func(self, *args, **kwargs) + finally: + self.train(was_training) + self.to(original_device) + for name, parameter in self.named_parameters(): + if name in original_requires_grad: + parameter.requires_grad = original_requires_grad[name] + else: + parameter.requires_grad = False + return wrapper + + +# ============================================================================= +# FastEsmFoldConfig +# ============================================================================= + +class FastEsmFoldConfig(EsmConfig): + model_type = "fast_esmfold" + + def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__(**kwargs) + self.attn_backend = attn_backend + self.ttt_config = ttt_config or { + "lr": 4e-4, + "steps": 10, + "lora_rank": 8, + "lora_alpha": 32.0, + } + + +# ============================================================================= +# FastEsmForProteinFolding +# ============================================================================= + +class FastEsmForProteinFolding(EsmForProteinFolding): + """ESMFold with FastESM2 attention backends + built-in Test-Time Training. + + Inherits all folding logic (trunk, structure module, output_to_pdb, infer) + from transformers.EsmForProteinFolding. 
Replaces the ESM2 backbone with + FastESM2 for optimized attention and adds TTT for improved structure prediction. + + Key API: + result = model.fold_protein("MKTL...", ttt=True) + # result = {"plddt": float, "ptm": float, "pdb_string": str} + """ + config_class = FastEsmFoldConfig + + def __init__(self, config: FastEsmFoldConfig): + super().__init__(config) + + # Replace standard ESM2 backbone with FastESM2 (multi-backend attention) + # unless use_standard_backbone is set (for TTT debugging/compatibility) + if not config.ttt_config.get("use_standard_backbone", False): + self.esm = FastEsmBackbone(config) + self.esm.requires_grad_(False) + if config.esmfold_config.fp16_esm: + self.esm.half() + + # MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear) + self.mlm_head = EsmLMHead(config) + + # TTT state (lazy initialization) + ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"} + self._ttt_cfg = TTTConfig(**ttt_kwargs) + self._ttt_cfg.verify() + self._ttt_initialized = False + self._ttt_initial_state = None + self._ttt_generator = torch.Generator() + if self._ttt_cfg.seed is not None: + self._ttt_generator.manual_seed(self._ttt_cfg.seed) + self._non_special_tokens_cache = None + self._ttt_tokenizer = None + + def _get_ttt_tokenizer(self) -> EsmTokenizer: + if self._ttt_tokenizer is None: + self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") + return self._ttt_tokenizer + + def _ensure_ttt_ready(self) -> None: + """Lazy TTT initialization. Injects LoRA adapters and saves initial state. 
+ Must be called after weights are loaded (not in __init__).""" + if self._ttt_initialized: + return + self._ttt_initialized = True + + tokenizer = self._get_ttt_tokenizer() + vocab = tokenizer.get_vocab() + self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] + + if self._ttt_cfg.lora_rank > 0: + self.mlm_head.eval() + for p in self.mlm_head.parameters(): + p.requires_grad = False + # Seed global state before LoRA init for reproducible weight initialization + if self._ttt_cfg.seed is not None: + torch.manual_seed(self._ttt_cfg.seed) + self._inject_lora() + else: + # Legacy path: jointly-trained random linear projection head + H = self.config.hidden_size + V = self.config.vocab_size + device = next(self.esm.parameters()).device + self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device) + + if self._ttt_cfg.initial_state_reset: + self._ttt_initial_state = self._ttt_get_state() + + @property + def _uses_lora(self) -> bool: + return self._ttt_cfg.lora_rank > 0 + + def _inject_lora(self) -> None: + """Inject LoRA adapters into ESM2 attention layers (matching lora_diffusion behavior).""" + self._lora_params = inject_trainable_lora( + self.esm, + target_class_name=self._ttt_cfg.lora_target_class, + r=self._ttt_cfg.lora_rank, + scale=self._ttt_cfg.lora_alpha, + ) + assert len(self._lora_params) > 0, ( + f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' " + f"matches attention modules in the backbone." 
+ ) + + # ---- TTT State Management ---- + + def _get_lora_modules(self) -> List[LoraInjectedLinear]: + """Find all LoraInjectedLinear modules in the backbone.""" + return [m for m in self.esm.modules() if isinstance(m, LoraInjectedLinear)] + + def _ttt_get_state(self) -> Dict[str, Any]: + if self._uses_lora: + lora_state = [] + for m in self._get_lora_modules(): + lora_state.append({ + "down": m.lora_down.weight.data.clone(), + "up": m.lora_up.weight.data.clone(), + }) + return {"_lora_state": lora_state} + return { + "esm": copy.deepcopy(self.esm), + "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj), + } + + def _ttt_set_state(self, state: Dict[str, Any]) -> None: + if "_lora_state" in state: + modules = self._get_lora_modules() + assert len(modules) == len(state["_lora_state"]) + for m, saved in zip(modules, state["_lora_state"]): + m.lora_down.weight.data.copy_(saved["down"]) + m.lora_up.weight.data.copy_(saved["up"]) + return + if "esm" in state: + self.esm = copy.deepcopy(state["esm"]) + if "_ttt_lm_proj" in state: + self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"]) + + def ttt_reset(self) -> None: + """Reset model to pre-TTT state (restore initial LoRA or backbone weights).""" + assert self._ttt_initial_state is not None, "TTT reset requires initial_state_reset=True." 
        self._ttt_set_state(self._ttt_initial_state)

    # ---- TTT Core ----

    def _ttt_tokenize(self, seq: str) -> torch.Tensor:
        """Tokenize one sequence to input ids of shape [1, seq_len].

        Special tokens (cls/eos) are added only on the LoRA path, matching the
        input format the pretrained EsmLMHead expects; the legacy random
        projection head trains on bare residue tokens.
        """
        tokenizer = self._get_ttt_tokenizer()
        out = tokenizer(
            seq,
            return_tensors="pt",
            add_special_tokens=self._uses_lora,
            padding=False,
            truncation=False,
        )
        return out["input_ids"]

    def _ttt_mask_token(self) -> int:
        # <mask> token id of the ESM2 tokenizer.
        return self._get_ttt_tokenizer().mask_token_id

    def _ttt_get_non_special_tokens(self) -> List[int]:
        """Token ids of the 20 standard amino acids (cached after first lookup)."""
        if self._non_special_tokens_cache is not None:
            return self._non_special_tokens_cache
        tokenizer = self._get_ttt_tokenizer()
        vocab = tokenizer.get_vocab()
        self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
        return self._non_special_tokens_cache

    def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor:
        """Run ESM2 backbone + LM head to get MLM logits."""
        # Temporarily unfreeze backbone for gradient flow during TTT
        output = self.esm(input_ids=batch)
        hidden = output.last_hidden_state
        if self._uses_lora:
            return self.mlm_head(hidden)
        return self._ttt_lm_proj(hidden)

    def _ttt_sample_batch(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build one BERT-style masked batch from tokenized sequence *x* [1, L].

        Returns (masked_inputs, targets, mask, crop_start_indices). All random
        choices come from the dedicated TTT generator, keeping runs reproducible
        and independent of global torch RNG state.
        """
        _, seq_len = x.shape
        batch_size = self._ttt_cfg.batch_size
        crop_size = min(self._ttt_cfg.crop_size, seq_len)

        x_expanded = x.expand(batch_size, -1)
        if seq_len == crop_size:
            start_indices = torch.zeros(batch_size, dtype=torch.long)
        else:
            # Independent random crop window per batch row for long sequences.
            start_indices = torch.randint(
                0, seq_len - crop_size + 1, (batch_size,),
                generator=self._ttt_generator,
            ).to(torch.long)

        batch_cropped = torch.stack([
            x_expanded[index, start : start + crop_size]
            for index, start in enumerate(start_indices)
        ])

        non_special_tokens = set(self._ttt_get_non_special_tokens())
        mask = torch.zeros((batch_size, crop_size), dtype=torch.bool)
        mask_token_id = self._ttt_mask_token()

        # Choose mask positions only among standard amino-acid tokens so that
        # cls/eos (and any non-residue token) is never used as an MLM target.
        for row_index in range(batch_size):
            non_special_positions = [
                col for col in range(crop_size)
                if batch_cropped[row_index, col].item() in non_special_tokens
            ]
            assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token."
            # At least one position per row so the loss is always defined.
            num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio)))
            sampled_indices = torch.randperm(
                len(non_special_positions), generator=self._ttt_generator,
            )[:num_to_mask]
            positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices]
            mask[row_index, positions_to_mask] = True

        # BERT-style corruption of the selected positions:
        #   p < 1 - leave - replace  -> <mask> token
        #   p < 1 - leave            -> random standard residue
        #   otherwise                -> leave original token
        batch_masked = batch_cropped.clone()
        for row_index in range(batch_size):
            masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0]
            for masked_position in masked_positions:
                probability = float(torch.rand(1, generator=self._ttt_generator).item())
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob:
                    batch_masked[row_index, masked_position] = mask_token_id
                    continue
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob:
                    replacement_candidates = self._ttt_get_non_special_tokens()
                    replacement_index = int(torch.randint(
                        0, len(replacement_candidates), (1,), generator=self._ttt_generator,
                    ).item())
                    batch_masked[row_index, masked_position] = replacement_candidates[replacement_index]

        return batch_masked, batch_cropped, mask, start_indices

    def _ttt_cross_entropy_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Cross-entropy over masked positions, averaged per-sequence first."""
        assert logits.ndim == 3, "Logits must be [batch, seq, vocab]."
        _, _, vocab_size = logits.shape
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)
        mask_flat = mask.reshape(-1)
        assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token."
        loss = F.cross_entropy(
            logits_flat[mask_flat],
            targets_flat[mask_flat],
            reduction="none",
        )
        # Average within each sequence first so rows with more masked tokens
        # do not dominate the batch loss.
        masked_tokens_per_seq = mask.sum(dim=1).tolist()
        per_sequence_losses = torch.split(loss, masked_tokens_per_seq)
        return torch.stack([sl.mean() for sl in per_sequence_losses]).mean()

    def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Build the TTT optimizer ("sgd" or "adamw" per TTTConfig.verify)."""
        if self._ttt_cfg.optimizer == "sgd":
            return torch.optim.SGD(
                parameters,
                lr=self._ttt_cfg.lr,
                momentum=self._ttt_cfg.momentum,
                weight_decay=self._ttt_cfg.weight_decay,
            )
        return torch.optim.AdamW(
            parameters,
            lr=self._ttt_cfg.lr,
            weight_decay=self._ttt_cfg.weight_decay,
        )

    def _lora_ttt(self, seq: str) -> Dict[str, List[float]]:
        """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen."""
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Freeze everything, then re-enable gradients on the LoRA factors only.
        for parameter in self.parameters():
            parameter.requires_grad = False
        for p in self._lora_params:
            p.requires_grad = True
        optimizer = self._ttt_get_optimizer(self._lora_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        # ags micro-batches are accumulated per optimizer.step(), so the loop
        # runs steps * ags iterations total.
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}

    def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]:
        """Legacy TTT: full fine-tuning of ESM2 backbone with random linear projection head."""
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Full fine-tune: all backbone params trainable
        for parameter in self.parameters():
            parameter.requires_grad = False
        for parameter in self.esm.parameters():
            parameter.requires_grad = True
        if self._ttt_cfg.freeze_embeddings:
            for parameter in self.esm.embeddings.parameters():
                parameter.requires_grad = False
        for parameter in self._ttt_lm_proj.parameters():
            parameter.requires_grad = True

        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = self._ttt_get_optimizer(trainable_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}

    @preserve_model_state
    def ttt(self, seq: str) -> Dict[str, List[float]]:
        """Run test-time training on a single sequence using masked language modeling.

        Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence
        before structure prediction. Call fold_protein(seq) for the full pipeline.

        Args:
            seq: Protein sequence (single-letter amino acid codes)

        Returns:
            Dict with "losses" key containing per-step MLM loss values
        """
        self._ensure_ttt_ready()
        # TTT requires fp32 for stable gradient computation. ESMFold typically
        # runs the backbone in fp16, but small LoRA updates vanish in half precision.
        esm_dtype = next(self.esm.parameters()).dtype
        if esm_dtype != torch.float32:
            self.esm.float()
            self.mlm_head.float()
        if self._uses_lora:
            result = self._lora_ttt(seq)
        else:
            result = self._legacy_ttt(seq)
        # Restore original dtype (backbone back to fp16 for inference)
        if esm_dtype != torch.float32:
            self.esm.to(esm_dtype)
            self.mlm_head.to(esm_dtype)
        return result

    # ---- High-Level API ----

    def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]:
        """Fold a sequence once and return pLDDT, ptm, and optionally PDB string."""
        with torch.no_grad():
            output = self.infer(sequence)
        plddt = output["plddt"]
        # plddt may be [batch, residues] or flat — reduce to one scalar mean.
        if plddt.dim() >= 2:
            mean_plddt = float(plddt.mean(dim=-1).mean().item())
        else:
            mean_plddt = float(plddt.mean().item())
        result = {
            "plddt": mean_plddt,
            "ptm": float(output["ptm"].item()) if "ptm" in output else None,
        }
        if return_pdb_string:
            pdb_strings = self.output_to_pdb(output)
            result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings
        return result

    def fold_protein(
        self,
        sequence: str,
        return_pdb_string: bool = True,
    ) -> Dict[str, Any]:
        """Fold a protein sequence with test-time training.

        Runs TTT (masked language model adaptation via LoRA) for the configured
        number of steps, folding after each optimizer step to track pLDDT. Returns
        the structure with the highest pLDDT across all steps (including baseline).

        Args:
            sequence: Protein sequence (single-letter amino acid codes)
            return_pdb_string: If True, include PDB string in output

        Returns:
            Dict with keys:
                - plddt: float, best mean pLDDT across all TTT steps
                - ptm: float, predicted TM-score from best step
                - pdb_string: str (if return_pdb_string=True), PDB from best step
                - step_plddts: list[float], pLDDT at each step [baseline, s1, ..., s10]
                - best_step: int, which step produced best structure (0=baseline)
        """
        self._ensure_ttt_ready()

        # Cast to fp32 for TTT stability
        esm_dtype = next(self.esm.parameters()).dtype
        if esm_dtype != torch.float32:
            self.esm.float()
            self.mlm_head.float()

        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"

        # Step 0: baseline fold (no TTT adaptation)
        best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
        step_plddts = [best["plddt"]]

        if self._ttt_cfg.steps > 0:
            # Tokenize for masked LM training
            x = self._ttt_tokenize(sequence)

            # Freeze all, unfreeze LoRA
            for p in self.parameters():
                p.requires_grad = False
            if self._uses_lora:
                for p in self._lora_params:
                    p.requires_grad = True
                optimizer = self._ttt_get_optimizer(self._lora_params)
            else:
                # Legacy path mirrors _legacy_ttt's trainable-parameter setup.
                for p in self.esm.parameters():
                    p.requires_grad = True
                if self._ttt_cfg.freeze_embeddings:
                    for p in self.esm.embeddings.parameters():
                        p.requires_grad = False
                for p in self._ttt_lm_proj.parameters():
                    p.requires_grad = True
                trainable = [p for p in self.parameters() if p.requires_grad]
                optimizer = self._ttt_get_optimizer(trainable)
            optimizer.zero_grad(set_to_none=True)

            self.eval()
            # steps * ags micro-batches; one optimizer.step() every ags of them.
            for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
                batch_masked, targets, mask, _start = self._ttt_sample_batch(x)
                batch_masked = batch_masked.to(device, non_blocking=non_blocking)
                targets = targets.to(device, non_blocking=non_blocking)
                mask = mask.to(device, non_blocking=non_blocking)

                self.train()
                logits = self._ttt_predict_logits(batch_masked)
                loss = self._ttt_cross_entropy_loss(logits, targets, mask)
                loss.backward()

                if (step + 1) % self._ttt_cfg.ags == 0:
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                    # Fold after this optimizer step
                    self.eval()
                    current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
                    step_plddts.append(current["plddt"])
                    if current["plddt"] > best["plddt"]:
                        best = current

            self.eval()

            # Restore requires_grad
            for p in self.parameters():
                p.requires_grad = False

            # Reset LoRA weights for next sequence
            self.ttt_reset()

        # Restore dtype
        if esm_dtype != torch.float32:
            self.esm.to(esm_dtype)
            self.mlm_head.to(esm_dtype)

        return {
            "plddt": best["plddt"],
            "ptm": best["ptm"],
            "pdb_string": best.get("pdb_string"),
            "step_plddts": step_plddts,
            # Index 0 is the baseline fold; >0 means TTT improved the structure.
            "best_step": step_plddts.index(max(step_plddts)),
        }