# Source: FastESMFold / modeling_fast_esmfold.py
# (uploaded by lhallee with huggingface_hub; revision a53965b, verified)
from __future__ import annotations
import torch
import torch._inductor.config as inductor_config
import torch._dynamo as dynamo
# NOTE: every statement below mutates process-wide PyTorch settings as a side effect of importing
# this module, so it also affects any other PyTorch code running in the same process.
# Allow TensorFloat32 tensor cores for float32 matmuls (Ampere+ GPUs): a significant speedup in
# exchange for reduced float32 matmul precision.
torch.set_float32_matmul_precision('high')
# Explicitly enable TF32 for CUDA matmuls and for cuDNN as well.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# cuDNN autotuner: benchmarks candidate algorithms per input shape and caches the fastest one.
# Pays off when shapes repeat; each new shape triggers fresh benchmarking on first use.
torch.backends.cudnn.benchmark = True
# Non-deterministic algorithms are allowed for speed; set True if reproducibility is required.
torch.backends.cudnn.deterministic = False
# GEMM backends that inductor's max-autotune mode may choose from.
inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
# Let dynamo capture scalar outputs (e.g. `.item()`) in the graph instead of graph-breaking.
dynamo.config.capture_scalar_outputs = True
# Upper bound on dynamo recompilations of a single compiled frame.
torch._dynamo.config.recompile_limit = 16
"""Shared attention infrastructure for all FastPLMs models.
Contains: AttentionBackend enum, backend resolution, mask creation,
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
"""
from enum import Enum
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
# Flex attention only exists in newer PyTorch releases. When it is missing, all three names are
# bound to None and callers check for None before using them.
try:
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
except ImportError:
    BlockMask = create_block_mask = flex_attention = None
# Cache for the torch.compile-d flex attention wrapper; built on first use.
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return the flex attention callable to use, or None when flex attention is unavailable.

    By default this is a ``torch.compile``-d wrapper around ``flex_attention`` (created once and
    cached in ``_compiled_flex_attention``) so it runs as a fused kernel. If the flex attention
    module's ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG`` flag is truthy, the eager function is
    returned instead.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    eager_debug = getattr(
        torch.nn.attention.flex_attention, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False
    )
    if eager_debug:
        return flex_attention
    if _compiled_flex_attention is None:
        # Static-shape compilation (dynamic=False).
        _compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
    return _compiled_flex_attention
### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
return "flash_attn2"
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
return "flash_attn3"
return None
def _try_get_kernels_flash():
try:
from kernels import get_kernel
except ImportError:
return None, None
flash_kernel = None
flash_kernel_variant = None
try:
flash_kernel = get_kernel("kernels-community/flash-attn3")
flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
except Exception:
try:
flash_kernel = get_kernel("kernels-community/flash-attn2")
flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
except Exception:
flash_kernel = None
flash_kernel_variant = None
return flash_kernel, flash_kernel_variant
# Module-level cache of the flash kernel lookup, filled in by _ensure_flash_kernels_loaded().
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Populate FLASH_KERNEL / FLASH_KERNEL_VARIANT on the first call; later calls do nothing.

    The loaded flag is set before the lookup runs, so the lookup is attempted at most once even
    when it finds no usable kernel.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Run dense (non-varlen) flash attention through the loaded hub kernel.

    No softmax scale is passed to the kernel, so the kernel's default scaling applies.
    Callers in this file pass (batch, seq, heads, head_dim) tensors.

    Raises:
        AssertionError: if no kernel is loaded or its variant is not recognized.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        # The attention output is element 0 of the kernel's result.
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
    except TypeError:
        # Builds that reject these keyword names: retry positionally as (q, k, v, 0.0, None, causal).
        result = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
    return result[0] if isinstance(result, tuple) else result
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Run variable-length (packed) flash attention through the loaded hub kernel.

    Inputs are packed token rows, such as those built by ``_unpad_input``. ``cu_seqlens_*`` are
    cumulative sequence offsets and ``max_seqlen_in_batch_*`` the longest sequence length. No
    softmax scale is passed, so the kernel's default scaling applies.

    Raises:
        AssertionError: if no kernel is loaded or its variant is not recognized.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        return FLASH_KERNEL.varlen_fwd(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    keyword_args = dict(
        q=query_states,
        k=key_states,
        v=value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        causal=causal,
    )
    try:
        result = FLASH_KERNEL.flash_attn_varlen_func(**keyword_args)
    except TypeError:
        # Builds that reject these keyword names: retry with the positional layout
        # (q, k, v, cu_q, cu_k, max_q, max_k, 0.0, None, causal).
        result = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, None, causal,
        )
    return result[0] if isinstance(result, tuple) else result
### Unpad / Pad helpers for varlen flash attention
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable ``input[indices]`` along dim 0: gather in forward, scatter in backward.

    ``input`` must have at least 2 dims; the trailing dims are flattened for the gather and
    restored afterwards. ``indices`` is a 1-D index tensor into dim 0.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        trailing_shape = input.shape[1:]
        ctx.first_axis_dim = input.shape[0]
        row_width = trailing_shape.numel()
        flat_rows = input.reshape(ctx.first_axis_dim, row_width)
        gather_index = indices.unsqueeze(1).expand(-1, row_width)
        return torch.gather(flat_rows, 0, gather_index).reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        row_width = trailing_shape.numel()
        flat_grad = grad_output.reshape(grad_output.shape[0], row_width)
        # Rows that were not selected in forward receive zero gradient.
        grad_input = torch.zeros(
            [ctx.first_axis_dim, row_width], device=flat_grad.device, dtype=flat_grad.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, row_width), flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of ``values`` rows into a zero tensor at dim-0 positions ``indices``.

    The output has ``first_axis_dim`` rows. Rows not named in ``indices`` stay zero and receive
    no gradient.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        scattered = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        scattered[indices] = values
        return scattered

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        # Only the rows written in forward came from `values`.
        values_grad = grad_output[indices]
        return values_grad, None, None
# Functional aliases: `index_first_axis(x, indices)` gathers rows and
# `index_put_first_axis(values, indices, n_rows)` scatters them back, both differentiable.
index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed token rows back into a zero-padded ``(batch, seqlen, ...)`` tensor.

    Row ``i`` of ``hidden_states`` lands at flat position ``indices[i]`` of the
    ``batch * seqlen`` grid; all other positions are zero.
    """
    flat_grid = index_put_first_axis(hidden_states, indices, batch * seqlen)
    unflattened = rearrange(flat_grid, "(b s) ... -> b s ...", b=batch)
    return unflattened
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Pack q/k/v down to their non-padding tokens for varlen flash attention.

    Args:
        query_layer: ``(batch, seq_len, num_heads, head_dim)``. ``key_layer`` and ``value_layer``
            are reshaped with the same dims, so they must share this shape (self-attention).
        attention_mask_2d: ``(batch, seq_len)``; nonzero/True marks real tokens.

    Returns:
        ``(q, k, v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen))``:
        - ``q``, ``k``, ``v``: packed ``(total_tokens, num_heads, head_dim)`` tensors.
        - ``indices``: flat positions of the kept tokens.
        - ``cu_seqlens``: int32 cumulative lengths with a leading 0.
        - ``max_seqlen``: the longest sequence.
        The same offsets and max are reused for queries and keys.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    lengths = attention_mask_2d.sum(dim=1).int()
    cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    # `.item()` forces a host sync; varlen kernels need the max length as a Python int.
    max_seqlen = int(lengths.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q, packed_k, packed_v = (
        index_first_axis(tensor.reshape(flat_shape), indices)
        for tensor in (query_layer, key_layer, value_layer)
    )
    return packed_q, packed_k, packed_v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    """Flash attention over ``(batch, seq, heads, head_dim)`` tensors, skipping padding when possible.

    For non-causal calls with a 2D padding mask, this:
    1. strips padded tokens,
    2. runs the varlen kernel on the packed rows,
    3. scatters the result back, so padded positions come back as zeros.

    Otherwise the dense kernel runs on the full tensors. When ``causal`` is True, any padding
    mask is ignored.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if causal or attention_mask_2d is None:
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
        )
    batch_size, q_len = query_states.shape[:2]
    (
        packed_q, packed_k, packed_v,
        token_indices, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
    ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    packed_output = _kernels_flash_varlen_forward(
        query_states=packed_q, key_states=packed_k, value_states=packed_v,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
    )
    return pad_input(packed_output, token_indices, batch_size, q_len)
### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    """Attention implementations selectable through ``config.attn_backend``.

    ``AUTO`` is never used directly; ``resolve_attention_backend`` maps it to one of the
    concrete backends.
    """

    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


# Accepted string values for `config.attn_backend`, in declaration order.
VALID_ATTENTION_BACKENDS = tuple(member.value for member in AttentionBackend.__members__.values())
# Set once the resolved backend has been printed, so the message appears only once.
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a configured backend name to a concrete, available ``AttentionBackend``.

    ``"auto"`` prefers kernels flash attention, then flex attention, then SDPA. Explicit
    ``"kernels_flash"`` / ``"flex"`` requests fail if that backend is unavailable. The first
    resolution is printed once.

    Args:
        requested_backend: One of ``VALID_ATTENTION_BACKENDS``.

    Returns:
        The resolved backend (never ``AUTO``).

    Raises:
        AssertionError: for an unknown name, or an explicitly requested backend that is not
            available. These are explicit ``raise`` statements rather than ``assert`` so the
            checks still run under ``python -O``.
    """
    global _BACKEND_CONFIRMED
    if requested_backend not in VALID_ATTENTION_BACKENDS:
        raise AssertionError(
            f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
        # Only look up hub kernels when flash attention could actually be selected.
        _ensure_flash_kernels_loaded()
    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        if FLASH_KERNEL is None:
            raise AssertionError("Kernels Flash Attention is not available in this environment.")
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        if flex_attention is None:
            raise AssertionError("Flex Attention is not available in this environment.")
        resolved = AttentionBackend.FLEX
    else:
        # Validated above, so the only remaining value is "sdpa".
        resolved = AttentionBackend.SDPA
    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build the padding masks shared by every encoder layer, in the form each backend needs.

    Excluded from torch.compile tracing via ``torch.compiler.disable``.

    Args:
        effective_backend: The resolved attention backend.
        batch_size, seq_len, device: Used to build the flex ``BlockMask``.
        attention_mask: Optional padding mask, presumably ``(batch, seq_len)`` with
            nonzero/True at real tokens (confirm against callers). ``None`` disables masking.

    Returns:
        ``(attention_mask_2d, attention_mask_4d, flex_block_mask)``:
        - no mask given: ``(None, None, None)``
        - KERNELS_FLASH: ``(bool 2D mask, None, None)``
        - FLEX: ``(bool 2D mask, None, BlockMask)``
        - anything else (SDPA): ``(bool 2D mask, bool key mask of shape (batch, 1, 1, seq_len), None)``
    """
    if attention_mask is None:
        return None, None, None
    mask_2d = attention_mask.bool()
    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The flash path unpads directly from the 2D mask.
        return mask_2d, None, None
    if effective_backend != AttentionBackend.FLEX:
        # SDPA / manual: mask the key dimension only. Masking padded query rows as well would
        # make their scores all -inf, and softmax over such a row yields NaN.
        return mask_2d, mask_2d[:, None, None, :], None
    assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
    # Number of real tokens per sequence.
    # NOTE(review): comparing positions against this count assumes right padding (real tokens
    # form a prefix); a mask with gaps would not be reproduced exactly.
    valid_lens = mask_2d.sum(dim=-1)

    def padding_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Unlike the SDPA key mask, padded query rows are masked here too.
        limit = valid_lens[batch_idx]
        return (q_idx < limit) & (kv_idx < limit)

    block_mask = create_block_mask(padding_mask_mod, batch_size, 1, seq_len, seq_len, device=device)
    return mask_2d, None, block_mask
"""FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.
Usage:
from transformers import AutoModel
model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda()
# Basic folding
result = model.fold_protein("MKTLLILAVVA...")
print(result["plddt"], result["pdb_string"][:100])
# Folding with TTT (test-time training improves structure prediction)
result = model.fold_protein("MKTLLILAVVA...", ttt=True)
Dependencies: torch, transformers, einops, peft (for LoRA TTT only)
No dependency on: esm (fair-esm), proteinttt, openfold
"""
import copy
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
EsmContactPredictionHead,
EsmEmbeddings,
EsmIntermediate,
EsmLMHead,
EsmOutput,
EsmSelfOutput,
RotaryEmbedding,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
# =============================================================================
# Output Dataclass
# =============================================================================
@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Output of FastEsmEncoder and FastEsmBackbone.

    Keep the field order: ModelOutput also supports positional access (tuple conversion and
    integer indexing), so reordering fields would change what those return.
    """

    # Hidden states after the encoder's final LayerNorm.
    last_hidden_state: Optional[torch.Tensor] = None
    # When requested: the input to each layer, followed by the final normalized output.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # When requested: per-layer attention probabilities (computed by the manual attention path).
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
# =============================================================================
# FastESM2 Attention Layers (multi-backend: SDPA, Flash, Flex)
# =============================================================================
class EsmSelfAttention(nn.Module):
    """ESM2 multi-head self-attention with selectable backends.

    Supported backends:
    - kernels-hub flash attention,
    - flex attention,
    - PyTorch SDPA,
    - an explicit (manual) path, used whenever attention weights are requested.

    The ``query``/``key``/``value`` projection names match the Hugging Face ESM checkpoint keys,
    so pretrained weights load unchanged.

    Softmax scaling: SDPA, flex and the manual path all run with an effective scale of 1.0, so
    ``forward`` folds ``head_dim ** -0.5`` into the query for them. The kernels flash path passes
    no scale, so the kernel applies its own default (``1/sqrt(head_dim)`` in the flash-attn API).
    The query is therefore left unscaled for that path. Scaling in both places would divide the
    logits by ``head_dim`` instead of ``sqrt(head_dim)``.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )
        self.num_attention_heads = config.num_attention_heads
        # Exact integer division: divisibility is asserted above.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.scale = self.attention_head_size**-0.5
        self.dropout_prob = config.attention_probs_dropout_prob
        self.config = config
        self.attn_backend = resolve_attention_backend(config.attn_backend)
        self.position_embedding_type = position_embedding_type or config.position_embedding_type
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project to heads, apply rotary embeddings, and attend.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)``.
            attention_mask_2d: Padding mask, True at real tokens. Used by the flash path, and to
                build a key mask for the manual path when no 4D mask is given.
            attention_mask_4d: Boolean key mask broadcastable to ``(batch, heads, q, k)``,
                True = attend. Used by the SDPA and manual paths.
            flex_block_mask: BlockMask for the flex path.
            output_attentions: If True, use the manual path and also return attention probabilities.

        Returns:
            ``(output, weights)``: output has shape ``(batch, seq_len, hidden_size)``; weights are
            the attention probabilities, or None.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
        query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
        # The flash kernels scale internally; every other path runs with scale=1.0 and needs the
        # scaling folded into the query. Rotary embeddings are linear in the query, so scaling
        # before or after them is equivalent.
        kernel_applies_scale = (
            self.attn_backend == AttentionBackend.KERNELS_FLASH and not output_attentions
        )
        if not kernel_applies_scale:
            query_BHLD = query_BHLD * self.scale
        if self.position_embedding_type == "rotary":
            query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)
        attn_output, attn_weights = self._attn(
            query_BHLD, key_BHLD, value_BHLD,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return attn_output, attn_weights

    def _attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Dispatch to the configured backend; requesting attention weights forces the manual path."""
        if output_attentions:
            if attention_mask_4d is None and attention_mask_2d is not None:
                # The flash and flex mask builders supply only a 2D mask. Derive the same
                # key-only mask SDPA uses; otherwise padded keys would be attended to.
                attention_mask_4d = attention_mask_2d.bool()[:, None, None, :]
            return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
        elif self.attn_backend == AttentionBackend.FLEX:
            return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
        elif self.attn_backend == AttentionBackend.SDPA:
            return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Explicit softmax attention that also returns the probabilities. Expects a pre-scaled query."""
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
        if attention_mask_4d is not None:
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout_prob > 0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return attn_output, attn_weights

    def _kernels_flash_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Hub flash attention. Expects an UNscaled query; the kernel applies its default scale."""
        # The flash kernels take (batch, seq, heads, head_dim) contiguous tensors.
        query_BLHD = query_BHLD.transpose(1, 2).contiguous()
        key_BLHD = key_BHLD.transpose(1, 2).contiguous()
        value_BLHD = value_BHLD.transpose(1, 2).contiguous()
        attn_output = kernels_flash_attention_func(
            query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
            attention_mask_2d=attention_mask_2d, causal=False,
        )
        return rearrange(attn_output, "b s h d -> b s (h d)"), None

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: Optional[BlockMask] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Flex attention with scale=1.0 (the query is pre-scaled in forward)."""
        assert flex_attention is not None, "Flex attention is not available in this environment."
        fn = _get_flex_attention_fn()
        context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """PyTorch SDPA with scale=1.0 (the query is pre-scaled in forward); boolean mask, True = attend."""
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=attention_mask_4d,
            dropout_p=self.dropout_prob if self.training else 0.0,
            scale=1.0,
        )
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention sub-block: LayerNorm, then EsmSelfAttention, then EsmSelfOutput.

    EsmSelfOutput receives the un-normalized input as its second argument, presumably as the
    residual (confirm against transformers' EsmSelfOutput).
    """

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(block output, attention probabilities or None)``."""
        normalized = self.LayerNorm(hidden_states)
        context, attn_weights = self.self(
            normalized,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self.output(context, hidden_states), attn_weights
class EsmLayer(nn.Module):
    """One pre-LN ESM2 transformer layer: the attention sub-block, then a pre-LN feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(layer output, attention probabilities or None)``."""
        post_attention, attn_weights = self.attention(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self._feed_forward(post_attention), attn_weights

    def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Run LayerNorm, then EsmIntermediate, then EsmOutput.

        EsmOutput also receives the un-normalized input, presumably as the residual.
        """
        expanded = self.intermediate(self.LayerNorm(attention_output))
        return self.output(expanded, attention_output)
class FastEsmEncoder(nn.Module):
    """Stack of EsmLayers followed by a final LayerNorm.

    Backend-specific padding masks are built once per forward pass and shared by every layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attention_backend = resolve_attention_backend(config.attn_backend)
        self.layer = nn.ModuleList(EsmLayer(config) for _ in range(config.num_hidden_layers))
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> FastEsmEncoderOutput:
        """Run all layers.

        Returns:
            FastEsmEncoderOutput with the final normalized hidden states. If requested, it also
            holds ``hidden_states`` (each layer's input plus the final output) and ``attentions``
            (each layer's attention probabilities); otherwise those are None.
        """
        mask_2d, mask_4d, block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
            device=hidden_states.device,
            attention_mask=attention_mask,
        )
        collected_states = [] if output_hidden_states else None
        collected_attn = [] if output_attentions else None
        for layer_module in self.layer:
            if collected_states is not None:
                collected_states.append(hidden_states)
            hidden_states, layer_attn = layer_module(
                hidden_states,
                attention_mask_2d=mask_2d,
                attention_mask_4d=mask_4d,
                flex_block_mask=block_mask,
                output_attentions=output_attentions,
            )
            if collected_attn is not None:
                collected_attn.append(layer_attn)
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)
        if collected_states is not None:
            collected_states.append(hidden_states)
        return FastEsmEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=None if collected_states is None else tuple(collected_states),
            attentions=None if collected_attn is None else tuple(collected_attn),
        )
# =============================================================================
# FastESM Backbone (replaces EsmModel inside ESMFold)
# =============================================================================
class FastEsmBackbone(nn.Module):
    """FastESM2 backbone with multi-backend attention; a drop-in replacement for
    transformers.EsmModel inside EsmForProteinFolding.

    State dict keys match HuggingFace EsmModel exactly, so pretrained weights load without any
    key remapping.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = FastEsmEncoder(config)
        # Not used in forward; presumably kept so EsmModel's contact_head.* checkpoint keys load.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastEsmEncoderOutput:
        """Embed the inputs and run the encoder.

        ``None`` for either ``output_*`` flag means False. ``return_dict`` and extra keyword
        arguments are accepted for EsmModel call compatibility but ignored; a
        FastEsmEncoderOutput is always returned.
        """
        want_attentions = False if output_attentions is None else output_attentions
        want_hidden_states = False if output_hidden_states is None else output_hidden_states
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            embedded,
            attention_mask=attention_mask,
            output_hidden_states=want_hidden_states,
            output_attentions=want_attentions,
        )
        return FastEsmEncoderOutput(
            last_hidden_state=encoded.last_hidden_state,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
# =============================================================================
# TTT (Test-Time Training) Configuration and Utilities
# =============================================================================
_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
class LoraInjectedLinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) update, like lora_diffusion.

    Computes ``linear(x) + lora_up(lora_down(x)) * scale``. Initialization follows
    cloneofsimo/lora:
    - ``lora_down`` weights ~ Normal(0, std=1/r),
    - ``lora_up`` weights start at zero, so the wrapper initially reproduces the wrapped layer.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features, out_features = original_linear.in_features, original_linear.out_features
        assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
        # Creation order matters for RNG reproducibility: down, then up, then re-init.
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.linear(x)
        low_rank_out = self.lora_up(self.lora_down(x))
        return base_out + low_rank_out * self.scale
def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Wrap every ``nn.Linear`` child of modules named ``target_class_name`` with LoRA.

    Mirrors lora_diffusion's ``inject_trainable_lora``:
    - find every module in ``model`` whose class name equals ``target_class_name``;
    - replace each of its direct ``nn.Linear`` children with a ``LoraInjectedLinear`` on the
      same device and dtype.

    Returns:
        The new trainable LoRA parameters: down then up weight for each wrapped layer, in
        traversal order.
    """
    lora_params: List[nn.Parameter] = []
    for parent_module in model.modules():
        if type(parent_module).__name__ != target_class_name:
            continue
        # Snapshot the children first: setattr below replaces entries on this module.
        linear_children = [
            (name, child)
            for name, child in parent_module.named_children()
            if isinstance(child, nn.Linear)
        ]
        for child_name, child_module in linear_children:
            wrapped = LoraInjectedLinear(child_module, r=r, scale=scale).to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, wrapped)
            lora_params += [*wrapped.lora_down.parameters(), *wrapped.lora_up.parameters()]
    return lora_params
@dataclass
class TTTConfig:
    """Hyper-parameters for test-time training (TTT).

    Fields marked "presumably" are consumed by the TTT loop; their meaning is inferred from
    their names.
    """

    # Optimizer learning rate (> 0).
    lr: float = 4e-4
    # Presumably gradient-accumulation steps per optimizer step (> 0).
    ags: int = 4
    # Presumably the number of TTT optimization steps (>= 0).
    steps: int = 10
    batch_size: int = 4
    # Presumably the fraction of positions masked for the MLM objective, in (0, 1].
    mask_ratio: float = 0.15
    # Presumably the maximum crop length for long sequences.
    crop_size: int = 1024
    # Presumably BERT-style masking probabilities; each in [0, 1] and their sum <= 1.
    bert_leave_prob: float = 0.1
    bert_replace_prob: float = 0.1
    # "sgd" or "adamw".
    optimizer: str = "sgd"
    momentum: float = 0.0
    weight_decay: float = 0.0
    # Seeds the TTT generator and the LoRA initialization; None leaves them unseeded.
    seed: Optional[int] = 0
    # Snapshot the pre-TTT state so ttt_reset() can restore it.
    initial_state_reset: bool = True
    freeze_embeddings: bool = True
    # > 0: train LoRA adapters of this rank; 0: legacy trainable LM projection head.
    lora_rank: int = 8
    # Used directly as the LoRA output scale.
    lora_alpha: float = 32.0
    # Class name whose nn.Linear children receive LoRA adapters.
    lora_target_class: str = "EsmSelfAttention"

    def verify(self) -> None:
        """Validate field ranges.

        Raises:
            AssertionError: on the first out-of-range field. The checks are explicit ``raise``
                statements rather than ``assert``, so they still run under ``python -O``.
        """

        def require(condition: bool, message: str) -> None:
            if not condition:
                raise AssertionError(message)

        require(self.lr > 0.0, "TTT learning rate must be positive.")
        require(self.ags > 0, "TTT ags must be positive.")
        require(self.steps >= 0, "TTT steps must be non-negative.")
        require(self.batch_size > 0, "TTT batch_size must be positive.")
        require(0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1].")
        require(self.crop_size > 0, "TTT crop_size must be positive.")
        require(0.0 <= self.bert_leave_prob <= 1.0, "TTT bert_leave_prob must be in [0, 1].")
        require(0.0 <= self.bert_replace_prob <= 1.0, "TTT bert_replace_prob must be in [0, 1].")
        require(
            self.bert_leave_prob + self.bert_replace_prob <= 1.0,
            "TTT bert_leave_prob + bert_replace_prob must not exceed 1.",
        )
        require(
            self.optimizer in {"sgd", "adamw"},
            f"TTT optimizer must be 'sgd' or 'adamw', got {self.optimizer!r}.",
        )
        require(self.lora_rank >= 0, "TTT lora_rank must be non-negative.")
        require(self.lora_alpha > 0.0, "TTT lora_alpha must be positive.")
def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator for ``nn.Module`` methods that restores module state after the call.

    Before the call it records:
    - ``self.training``,
    - the device of the first parameter,
    - every named parameter's ``requires_grad``.

    After the call (also if it raises) it:
    - restores the training mode,
    - moves the whole module back to the recorded device,
    - restores ``requires_grad``. Parameters that did not exist before the call (e.g. adapters
      added inside it) end up frozen.

    Only the first parameter's device is recorded, so this assumes a single-device module.
    """

    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        was_training = self.training
        original_device = next(self.parameters()).device
        saved_flags = dict(
            (name, parameter.requires_grad) for name, parameter in self.named_parameters()
        )
        try:
            return func(self, *args, **kwargs)
        finally:
            self.train(was_training)
            self.to(original_device)
            for name, parameter in self.named_parameters():
                parameter.requires_grad = saved_flags.get(name, False)

    return wrapper
# =============================================================================
# FastEsmFoldConfig
# =============================================================================
class FastEsmFoldConfig(EsmConfig):
    """EsmConfig for FastESMFold, adding the attention backend and test-time-training settings.

    Args:
        attn_backend: ``"auto"``, ``"kernels_flash"``, ``"flex"`` or ``"sdpa"``. Validated when
            the attention modules are built.
        ttt_config: Keyword arguments for TTTConfig, plus an optional ``"use_standard_backbone"``
            flag. A falsy value (None or ``{}``) selects the defaults below.
        **kwargs: Forwarded to EsmConfig.
    """

    model_type = "fast_esmfold"

    def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        if not ttt_config:
            ttt_config = {
                "lr": 4e-4,
                "steps": 10,
                "lora_rank": 8,
                "lora_alpha": 32.0,
            }
        self.ttt_config = ttt_config
# =============================================================================
# FastEsmForProteinFolding
# =============================================================================
class FastEsmForProteinFolding(EsmForProteinFolding):
"""ESMFold with FastESM2 attention backends + built-in Test-Time Training.
Inherits all folding logic (trunk, structure module, output_to_pdb, infer)
from transformers.EsmForProteinFolding. Replaces the ESM2 backbone with
FastESM2 for optimized attention and adds TTT for improved structure prediction.
Key API:
result = model.fold_protein("MKTL...", ttt=True)
# result = {"plddt": float, "ptm": float, "pdb_string": str}
"""
config_class = FastEsmFoldConfig
def __init__(self, config: FastEsmFoldConfig):
    """Build ESMFold, swap in the FastESM2 backbone, and set up lazily-initialized TTT state.

    ``config.ttt_config`` supplies the TTTConfig fields. It may also contain
    ``use_standard_backbone=True`` to keep the stock transformers ESM2 backbone (for TTT
    debugging or compatibility).
    """
    super().__init__(config)
    ttt_options = dict(config.ttt_config)
    use_standard_backbone = ttt_options.pop("use_standard_backbone", False)
    if not use_standard_backbone:
        # Drop-in FastESM2 backbone with multi-backend attention; weight keys are unchanged.
        self.esm = FastEsmBackbone(config)
        self.esm.requires_grad_(False)
        if config.esmfold_config.fp16_esm:
            self.esm.half()
    # Pretrained MLM head (Dense -> GELU -> LN -> Linear) used as the TTT objective head.
    self.mlm_head = EsmLMHead(config)
    self._ttt_cfg = TTTConfig(**ttt_options)
    self._ttt_cfg.verify()
    # LoRA injection and state snapshots happen later in _ensure_ttt_ready().
    self._ttt_initialized = False
    self._ttt_initial_state = None
    self._ttt_generator = torch.Generator()
    if self._ttt_cfg.seed is not None:
        self._ttt_generator.manual_seed(self._ttt_cfg.seed)
    self._non_special_tokens_cache = None
    self._ttt_tokenizer = None
def _get_ttt_tokenizer(self) -> EsmTokenizer:
    """Return the ESM tokenizer used for TTT, loading it on first use and caching it on the instance.

    NOTE(review): the 8M ESM2 checkpoint's tokenizer is presumably used because ESM2 checkpoints
    share one vocabulary; confirm.
    """
    tokenizer = self._ttt_tokenizer
    if tokenizer is None:
        tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self._ttt_tokenizer = tokenizer
    return tokenizer
def _ensure_ttt_ready(self) -> None:
    """Lazily initialize test-time training (TTT); a no-op once initialization has succeeded.

    Steps:
      1. Cache the token ids of the standard amino acids found in the TTT tokenizer vocabulary.
      2. If ``lora_rank > 0``: freeze the MLM head (eval mode, no gradients), seed the global
         torch RNG with ``seed`` (if set) so LoRA initialization is reproducible, and inject
         LoRA adapters. Otherwise (legacy path): create a trainable
         ``nn.Linear(hidden_size, vocab_size)`` projection ``_ttt_lm_proj`` on the backbone's device.
      3. If ``initial_state_reset`` is set, snapshot the TTT state for ``ttt_reset``.

    Must be called after weights are loaded (not in __init__).

    The initialized flag is set only after every step succeeds. If an earlier step fails (for
    example, the tokenizer download raises), a later call retries instead of silently treating
    a half-initialized model as ready.
    """
    if self._ttt_initialized:
        return
    tokenizer = self._get_ttt_tokenizer()
    vocab = tokenizer.get_vocab()
    self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
    if self._ttt_cfg.lora_rank > 0:
        self.mlm_head.eval()
        for p in self.mlm_head.parameters():
            p.requires_grad = False
        # Seed global state before LoRA init for reproducible weight initialization.
        # NOTE: this reseeds the process-wide torch RNG, not just the TTT generator.
        if self._ttt_cfg.seed is not None:
            torch.manual_seed(self._ttt_cfg.seed)
        self._inject_lora()
    else:
        # Legacy path: jointly-trained random linear projection head
        H = self.config.hidden_size
        V = self.config.vocab_size
        device = next(self.esm.parameters()).device
        self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device)
    if self._ttt_cfg.initial_state_reset:
        self._ttt_initial_state = self._ttt_get_state()
    self._ttt_initialized = True
@property
def _uses_lora(self) -> bool:
    """True when TTT trains LoRA adapters (``lora_rank > 0``) rather than the legacy LM projection."""
    rank = self._ttt_cfg.lora_rank
    return rank > 0
def _inject_lora(self) -> None:
    """Patch the backbone's target modules with trainable LoRA adapters.

    Delegates to ``inject_trainable_lora`` (mirroring ``lora_diffusion``)
    using the configured target class name, rank and scale. The returned
    trainable parameters are stored on ``self._lora_params``.

    Raises:
        AssertionError: if nothing was injected, i.e. the target class name
            matched no module in the backbone.
    """
    cfg = self._ttt_cfg
    injected = inject_trainable_lora(
        self.esm,
        target_class_name=cfg.lora_target_class,
        r=cfg.lora_rank,
        scale=cfg.lora_alpha,
    )
    self._lora_params = injected
    assert len(self._lora_params) > 0, (
        f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' "
        f"matches attention modules in the backbone."
    )
# ---- TTT State Management ----
def _get_lora_modules(self) -> List[LoraInjectedLinear]:
    """Collect every ``LoraInjectedLinear`` in the backbone, in ``modules()`` traversal order."""
    found = []
    for module in self.esm.modules():
        if isinstance(module, LoraInjectedLinear):
            found.append(module)
    return found
def _ttt_get_state(self) -> Dict[str, Any]:
    """Snapshot the TTT-trainable state so it can be restored later.

    LoRA mode: ``{"_lora_state": [...]}`` with cloned ``down``/``up`` weights
    for each adapter, in ``_get_lora_modules()`` order.
    Legacy mode: deep copies of the whole backbone (``"esm"``) and of the
    projection head (``"_ttt_lm_proj"``).
    """
    if not self._uses_lora:
        return {
            "esm": copy.deepcopy(self.esm),
            "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj),
        }
    snapshots = [
        {
            "down": module.lora_down.weight.data.clone(),
            "up": module.lora_up.weight.data.clone(),
        }
        for module in self._get_lora_modules()
    ]
    return {"_lora_state": snapshots}
def _ttt_set_state(self, state: Dict[str, Any]) -> None:
    """Restore a snapshot produced by ``_ttt_get_state``.

    If the snapshot has a ``"_lora_state"`` entry, the saved weights are
    copied in place into the current adapters; the adapter count must match.
    Otherwise, any ``"esm"`` / ``"_ttt_lm_proj"`` entries are deep-copied
    back onto the model, so the snapshot itself stays untouched for future
    resets.
    """
    if "_lora_state" not in state:
        for key in ("esm", "_ttt_lm_proj"):
            if key in state:
                setattr(self, key, copy.deepcopy(state[key]))
        return
    saved_pairs = state["_lora_state"]
    adapters = self._get_lora_modules()
    assert len(adapters) == len(saved_pairs)
    for adapter, saved in zip(adapters, saved_pairs):
        adapter.lora_down.weight.data.copy_(saved["down"])
        adapter.lora_up.weight.data.copy_(saved["up"])
def ttt_reset(self) -> None:
    """Restore the TTT-trainable weights captured at TTT initialisation.

    Only possible when the snapshot was taken (``initial_state_reset=True``);
    otherwise an AssertionError is raised.
    """
    snapshot = self._ttt_initial_state
    assert snapshot is not None, "TTT reset requires initial_state_reset=True."
    self._ttt_set_state(snapshot)
# ---- TTT Core ----
def _ttt_tokenize(self, seq: str) -> torch.Tensor:
    """Tokenize ``seq`` into the ``input_ids`` tensor used for masked-LM TTT.

    No padding or truncation is applied. Special tokens are added only in
    LoRA mode; the legacy path tokenizes the raw residues alone.
    """
    tokenizer = self._get_ttt_tokenizer()
    encoded = tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=self._uses_lora,
        padding=False,
        truncation=False,
    )
    return encoded["input_ids"]
def _ttt_mask_token(self) -> int:
    """Return the tokenizer's mask-token id, used to corrupt TTT inputs."""
    tokenizer = self._get_ttt_tokenizer()
    return tokenizer.mask_token_id
def _ttt_get_non_special_tokens(self) -> List[int]:
    """Return the token ids of the standard amino acids (memoized).

    The list contains the characters of ``_ESM_STANDARD_AA`` that exist in
    the tokenizer vocabulary, in ``_ESM_STANDARD_AA`` order. It is cached on
    ``self._non_special_tokens_cache``.
    """
    if self._non_special_tokens_cache is None:
        vocab = self._get_ttt_tokenizer().get_vocab()
        self._non_special_tokens_cache = [
            vocab[residue] for residue in _ESM_STANDARD_AA if residue in vocab
        ]
    return self._non_special_tokens_cache
def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor:
    """Return masked-LM logits: backbone hidden states fed through the TTT head.

    LoRA mode uses the pretrained ``mlm_head``; legacy mode uses
    ``_ttt_lm_proj``. This method changes no ``requires_grad`` flags.
    Gradients reach whichever parameters the caller left trainable.
    """
    hidden = self.esm(input_ids=batch).last_hidden_state
    head = self.mlm_head if self._uses_lora else self._ttt_lm_proj
    return head(hidden)
def _ttt_sample_batch(
    self,
    x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw one BERT-style masked-LM training batch from a tokenized sequence.

    Each of the ``batch_size`` rows is a random ``crop_size`` window of ``x``
    (the full sequence when it is not longer than ``crop_size``). In each
    row, ``max(1, round(n * mask_ratio))`` of the row's ``n`` standard
    amino-acid positions are selected. Each selected position is then:

    - replaced by the mask token,
    - replaced by a random standard amino acid, or
    - left unchanged,

    according to ``bert_replace_prob`` / ``bert_leave_prob``. All randomness
    comes from ``self._ttt_generator``.

    Args:
        x: token ids with a leading batch axis, expanded to ``batch_size``
            rows (presumably ``[1, seq_len]`` from ``_ttt_tokenize``).

    Returns:
        ``(batch_masked, batch_cropped, mask, start_indices)``:
            - the corrupted inputs,
            - the uncorrupted crops (training targets),
            - the boolean selection mask,
            - each row's crop offset.

    Raises:
        AssertionError: if a crop contains no standard amino-acid token.
    """
    _, seq_len = x.shape
    cfg = self._ttt_cfg
    batch_size = cfg.batch_size
    crop_size = min(cfg.crop_size, seq_len)
    x_expanded = x.expand(batch_size, -1)
    if seq_len == crop_size:
        start_indices = torch.zeros(batch_size, dtype=torch.long)
    else:
        start_indices = torch.randint(
            0, seq_len - crop_size + 1, (batch_size,),
            generator=self._ttt_generator,
        ).to(torch.long)
    batch_cropped = torch.stack([
        x_expanded[index, start : start + crop_size]
        for index, start in enumerate(start_indices)
    ])
    # Standard amino-acid ids serve two purposes: they are the only positions
    # eligible for masking, and they form the pool of random replacements.
    # Fetched once instead of once per replaced token.
    replacement_candidates = self._ttt_get_non_special_tokens()
    # A vectorised membership test replaces a per-token ``.item()`` scan.
    # ``nonzero`` yields positions in ascending order, the same order as that
    # scan, so the generator draws below (and hence the sampled masks) are
    # unchanged.
    candidate_ids = torch.tensor(
        replacement_candidates, dtype=batch_cropped.dtype, device=batch_cropped.device,
    )
    eligible = torch.isin(batch_cropped, candidate_ids)
    mask = torch.zeros((batch_size, crop_size), dtype=torch.bool)
    mask_token_id = self._ttt_mask_token()
    for row_index in range(batch_size):
        non_special_positions = torch.nonzero(eligible[row_index], as_tuple=True)[0].cpu()
        num_candidates = int(non_special_positions.numel())
        assert num_candidates > 0, "Sequence must contain at least one non-special token."
        # Python's round() (banker's rounding), with at least one token masked.
        num_to_mask = max(1, int(round(num_candidates * cfg.mask_ratio)))
        sampled_indices = torch.randperm(
            num_candidates, generator=self._ttt_generator,
        )[:num_to_mask]
        mask[row_index, non_special_positions[sampled_indices]] = True
    # One uniform draw per selected token:
    #   [0, mask_threshold)              -> mask token
    #   [mask_threshold, replace_thresh) -> random amino acid
    #   [replace_thresh, 1)              -> leave the original token
    mask_threshold = 1.0 - cfg.bert_leave_prob - cfg.bert_replace_prob
    replace_threshold = 1.0 - cfg.bert_leave_prob
    batch_masked = batch_cropped.clone()
    for row_index in range(batch_size):
        for masked_position in torch.nonzero(mask[row_index], as_tuple=True)[0].tolist():
            probability = float(torch.rand(1, generator=self._ttt_generator).item())
            if probability < mask_threshold:
                batch_masked[row_index, masked_position] = mask_token_id
            elif probability < replace_threshold:
                replacement_index = int(torch.randint(
                    0, len(replacement_candidates), (1,), generator=self._ttt_generator,
                ).item())
                batch_masked[row_index, masked_position] = replacement_candidates[replacement_index]
    return batch_masked, batch_cropped, mask, start_indices
def _ttt_cross_entropy_loss(
    self,
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Masked-LM loss: averaged within each sequence, then across sequences.

    Cross-entropy is evaluated only where ``mask`` is True. Averaging per
    sequence first gives every sequence equal weight regardless of how many
    of its tokens were masked. Sequences with no masked tokens are skipped;
    their empty mean would be NaN and would poison the whole loss.

    Args:
        logits: ``[batch, seq, vocab]`` MLM logits.
        targets: ``[batch, seq]`` original token ids.
        mask: ``[batch, seq]`` boolean selection of positions to score.

    Returns:
        Scalar loss tensor.

    Raises:
        AssertionError: if ``logits`` is not 3-D or ``mask`` selects nothing.
    """
    assert logits.ndim == 3, "Logits must be [batch, seq, vocab]."
    _, _, vocab_size = logits.shape
    logits_flat = logits.reshape(-1, vocab_size)
    targets_flat = targets.reshape(-1)
    mask_flat = mask.reshape(-1)
    assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token."
    loss = F.cross_entropy(
        logits_flat[mask_flat],
        targets_flat[mask_flat],
        reduction="none",
    )
    # Flattened selection is row-major, so consecutive chunks of the per-token
    # losses belong to consecutive sequences.
    masked_tokens_per_seq = mask.sum(dim=1).tolist()
    per_sequence_losses = [
        chunk.mean()
        for chunk in torch.split(loss, masked_tokens_per_seq)
        if chunk.numel() > 0
    ]
    return torch.stack(per_sequence_losses).mean()
def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer:
    """Build the TTT optimizer over ``parameters``.

    ``optimizer == "sgd"`` selects SGD with the configured momentum; any
    other value selects AdamW. Learning rate and weight decay come from the
    TTT config in both cases.
    """
    cfg = self._ttt_cfg
    shared = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}
    if cfg.optimizer != "sgd":
        return torch.optim.AdamW(parameters, **shared)
    return torch.optim.SGD(parameters, momentum=cfg.momentum, **shared)
def _lora_ttt(self, seq: str) -> Dict[str, List[float]]:
    """Adapt only the LoRA adapter weights to ``seq`` with masked-LM training.

    Every other parameter, including ``mlm_head``, is frozen. The loop runs
    ``steps * ags`` micro-batches and takes an optimizer step every ``ags``
    micro-batches (gradient accumulation). The model is switched to train
    mode for each forward/backward and back to eval mode afterwards.

    Returns:
        ``{"losses": [...]}``: one loss per micro-batch, empty when
        ``steps == 0``.
    """
    tokens = self._ttt_tokenize(seq)
    device = next(self.parameters()).device
    async_copy = device.type == "cuda"
    history = []
    if self._ttt_cfg.steps == 0:
        return {"losses": history}
    # Freeze everything, then re-enable just the adapters.
    for param in self.parameters():
        param.requires_grad = False
    for param in self._lora_params:
        param.requires_grad = True
    optimizer = self._ttt_get_optimizer(self._lora_params)
    optimizer.zero_grad(set_to_none=True)
    self.eval()
    accumulation = self._ttt_cfg.ags
    total_micro_steps = self._ttt_cfg.steps * accumulation
    for micro_step in range(total_micro_steps):
        inputs, labels, loss_mask, _ = self._ttt_sample_batch(tokens)
        inputs, labels, loss_mask = (
            t.to(device, non_blocking=async_copy) for t in (inputs, labels, loss_mask)
        )
        self.train()
        loss = self._ttt_cross_entropy_loss(
            self._ttt_predict_logits(inputs), labels, loss_mask,
        )
        loss.backward()
        history.append(float(loss.detach().cpu().item()))
        if (micro_step + 1) % accumulation == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        self.eval()
    return {"losses": history}
def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]:
    """Fully fine-tune the ESM2 backbone (plus ``_ttt_lm_proj``) on ``seq``.

    All parameters are frozen, then two groups are re-enabled:

    - the backbone, minus its embeddings when ``freeze_embeddings`` is set;
    - the legacy projection head.

    The training loop matches the LoRA path: ``steps * ags`` micro-batches,
    with an optimizer step every ``ags`` micro-batches.

    Returns:
        ``{"losses": [...]}``: one loss per micro-batch, empty when
        ``steps == 0``.
    """
    tokens = self._ttt_tokenize(seq)
    device = next(self.parameters()).device
    async_copy = device.type == "cuda"
    history = []
    if self._ttt_cfg.steps == 0:
        return {"losses": history}
    for param in self.parameters():
        param.requires_grad = False
    self.esm.requires_grad_(True)
    if self._ttt_cfg.freeze_embeddings:
        self.esm.embeddings.requires_grad_(False)
    self._ttt_lm_proj.requires_grad_(True)
    trainable = [param for param in self.parameters() if param.requires_grad]
    optimizer = self._ttt_get_optimizer(trainable)
    optimizer.zero_grad(set_to_none=True)
    self.eval()
    accumulation = self._ttt_cfg.ags
    total_micro_steps = self._ttt_cfg.steps * accumulation
    for micro_step in range(total_micro_steps):
        inputs, labels, loss_mask, _ = self._ttt_sample_batch(tokens)
        inputs, labels, loss_mask = (
            t.to(device, non_blocking=async_copy) for t in (inputs, labels, loss_mask)
        )
        self.train()
        loss = self._ttt_cross_entropy_loss(
            self._ttt_predict_logits(inputs), labels, loss_mask,
        )
        loss.backward()
        history.append(float(loss.detach().cpu().item()))
        if (micro_step + 1) % accumulation == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        self.eval()
    return {"losses": history}
@preserve_model_state
def ttt(self, seq: str) -> Dict[str, List[float]]:
    """Run test-time training on a single sequence using masked language modeling.

    Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input
    sequence before structure prediction. The adapted weights are NOT reset
    afterwards. Call ``fold_protein(seq)`` for the full pipeline: TTT with
    per-step folding, best-pLDDT selection and a reset at the end.

    Args:
        seq: Protein sequence (single-letter amino acid codes)

    Returns:
        Dict with a "losses" key: one MLM loss per micro-batch
        (``steps * ags`` values; empty when ``steps == 0``).
    """
    self._ensure_ttt_ready()
    # TTT requires fp32 for stable gradient computation. ESMFold typically
    # runs the backbone in fp16, but small LoRA updates vanish in half
    # precision. The backbone and the MLM head can have different dtypes
    # (fp16_esm halves only the backbone), so each dtype is recorded and
    # restored separately. Forcing the head to the backbone's dtype would
    # round-trip fp32 pretrained head weights through fp16 and lose precision.
    esm_dtype = next(self.esm.parameters()).dtype
    head_dtype = next(self.mlm_head.parameters()).dtype
    needs_cast = esm_dtype != torch.float32 or head_dtype != torch.float32
    if needs_cast:
        self.esm.float()
        self.mlm_head.float()
    try:
        if self._uses_lora:
            return self._lora_ttt(seq)
        return self._legacy_ttt(seq)
    finally:
        # Restore the original precisions even if training raised.
        if needs_cast:
            self.esm.to(esm_dtype)
            self.mlm_head.to(head_dtype)
# ---- High-Level API ----
def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]:
    """Fold ``sequence`` once and summarise the prediction.

    Returns a dict with:
        - ``"plddt"``: mean pLDDT as a float.
        - ``"ptm"``: float, or None when the output has no ``"ptm"``.
        - ``"pdb_string"``: the PDB text, present only when
          ``return_pdb_string`` is True.
    """
    with torch.no_grad():
        output = self.infer(sequence)
    plddt = output["plddt"]
    # A 3-D pLDDT is per atom over the atom37 axis. Score only the CA atom
    # (index 1), which is what the PDB B-factor column reports.
    ndim = plddt.dim()
    if ndim == 3:
        ca_plddt = plddt[:, :, 1]
    elif ndim == 2:
        ca_plddt = plddt[:, 1]
    else:
        ca_plddt = plddt
    summary = {
        "plddt": float(ca_plddt.mean().item()),
        "ptm": float(output["ptm"].item()) if "ptm" in output else None,
    }
    if not return_pdb_string:
        return summary
    pdb = self.output_to_pdb(output)
    summary["pdb_string"] = pdb[0] if isinstance(pdb, list) else pdb
    return summary
def fold_protein(
    self,
    sequence: str,
    return_pdb_string: bool = True,
) -> Dict[str, Any]:
    """Fold a protein sequence with test-time training (TTT).

    Steps:

    1. Fold the unadapted model once (the baseline).
    2. Run masked-LM TTT for ``steps`` optimizer steps: LoRA adapters when
       ``lora_rank > 0``, otherwise the legacy full backbone fine-tune.
    3. Fold again after every optimizer step.

    Returns the structure with the highest mean pLDDT across all folds,
    baseline included.

    Cleanup runs even if folding or training raises:

    - parameters are frozen again after training;
    - the adapted weights are reset to the snapshot taken at TTT
      initialisation, when one exists (``initial_state_reset=True``);
    - the backbone and MLM head return to their original dtypes.

    Args:
        sequence: Protein sequence (single-letter amino acid codes)
        return_pdb_string: If True, include PDB string in output

    Returns:
        Dict with keys:
            - plddt: float, best mean pLDDT across all folds
            - ptm: float or None, predicted TM-score of the best fold
            - pdb_string: str or None, PDB of the best fold (None when
              return_pdb_string=False)
            - step_plddts: list[float], pLDDT per fold
              [baseline, s1, ..., s_steps]
            - best_step: int, index into step_plddts of the best fold
              (0 = baseline)
    """
    self._ensure_ttt_ready()
    # TTT needs fp32 (small LoRA updates vanish in fp16). The backbone and the
    # MLM head can have different dtypes (fp16_esm halves only the backbone),
    # so each is recorded and restored separately. Forcing the head to the
    # backbone's dtype would round-trip fp32 head weights through fp16.
    esm_dtype = next(self.esm.parameters()).dtype
    head_dtype = next(self.mlm_head.parameters()).dtype
    needs_cast = esm_dtype != torch.float32 or head_dtype != torch.float32
    if needs_cast:
        self.esm.float()
        self.mlm_head.float()
    device = next(self.parameters()).device
    non_blocking = device.type == "cuda"
    trained = False
    try:
        # Step 0: baseline fold (no TTT adaptation)
        best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
        step_plddts = [best["plddt"]]
        if self._ttt_cfg.steps > 0:
            # Tokenize for masked LM training
            x = self._ttt_tokenize(sequence)
            trained = True
            # Freeze all, then unfreeze the TTT-trainable parameters
            for p in self.parameters():
                p.requires_grad = False
            if self._uses_lora:
                for p in self._lora_params:
                    p.requires_grad = True
                optimizer = self._ttt_get_optimizer(self._lora_params)
            else:
                for p in self.esm.parameters():
                    p.requires_grad = True
                if self._ttt_cfg.freeze_embeddings:
                    for p in self.esm.embeddings.parameters():
                        p.requires_grad = False
                for p in self._ttt_lm_proj.parameters():
                    p.requires_grad = True
                trainable = [p for p in self.parameters() if p.requires_grad]
                optimizer = self._ttt_get_optimizer(trainable)
            optimizer.zero_grad(set_to_none=True)
            self.eval()
            for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
                batch_masked, targets, mask, _ = self._ttt_sample_batch(x)
                batch_masked = batch_masked.to(device, non_blocking=non_blocking)
                targets = targets.to(device, non_blocking=non_blocking)
                mask = mask.to(device, non_blocking=non_blocking)
                self.train()
                logits = self._ttt_predict_logits(batch_masked)
                loss = self._ttt_cross_entropy_loss(logits, targets, mask)
                loss.backward()
                if (step + 1) % self._ttt_cfg.ags == 0:
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    # Fold after this optimizer step
                    self.eval()
                    current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
                    step_plddts.append(current["plddt"])
                    if current["plddt"] > best["plddt"]:
                        best = current
                self.eval()
    finally:
        if trained:
            # Leave every parameter frozen after training
            for p in self.parameters():
                p.requires_grad = False
        # Reset adapted weights for the next sequence. Without a snapshot
        # (initial_state_reset=False), ttt_reset() would assert, so skip it.
        if self._ttt_initial_state is not None:
            self.ttt_reset()
        if needs_cast:
            self.esm.to(esm_dtype)
            self.mlm_head.to(head_dtype)
    return {
        "plddt": best["plddt"],
        "ptm": best["ptm"],
        "pdb_string": best.get("pdb_string"),
        "step_plddts": step_plddts,
        "best_step": step_plddts.index(max(step_plddts)),
    }