| from __future__ import annotations
|
|
|
| import torch
|
| import torch._inductor.config as inductor_config
|
| import torch._dynamo as dynamo
|
|
|
|
|
|
|
# Global PyTorch performance configuration, applied at import time.
# Enables TF32 matmuls on Ampere+ GPUs (faster, marginally lower precision).
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN benchmark algorithms per input shape and cache the fastest one.
torch.backends.cudnn.benchmark = True
# Speed over bitwise reproducibility for cuDNN ops.
torch.backends.cudnn.deterministic = False
# Backends inductor may autotune GEMMs against when max-autotune is active.
inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
# Allow torch.compile graphs to capture .item()-style scalar reads instead of breaking.
dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.recompile_limit = 16
|
|
|
| """Shared attention infrastructure for all FastPLMs models.
|
|
|
| Contains: AttentionBackend enum, backend resolution, mask creation,
|
| flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| """
|
| from enum import Enum
|
| from typing import Dict, List, Optional, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import functional as F
|
| from einops import rearrange
|
|
|
# flex_attention only exists in newer PyTorch builds; fall back to None
# sentinels so the rest of the module can feature-detect at runtime.
try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
    create_block_mask = None
    flex_attention = None
    BlockMask = None

# Cache for the torch.compile-wrapped flex_attention callable (filled lazily
# by _get_flex_attention_fn).
_compiled_flex_attention = None
|
|
|
|
|
def _get_flex_attention_fn():
    """Resolve which flex_attention callable to use.

    Returns the torch.compile-wrapped kernel (memoized in a module global),
    the eager kernel when the debug escape hatch is set on the flex module,
    or None when flex attention is unavailable in this torch build.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    flex_module = torch.nn.attention.flex_attention
    debug_eager = getattr(flex_module, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False)
    if debug_eager:
        return flex_attention
    if _compiled_flex_attention is None:
        # Compile once per process; dynamic=False keeps shapes specialized.
        _compiled_flex_attention = torch.compile(
            flex_attention,
            dynamic=False,
        )
    return _compiled_flex_attention
|
|
|
|
|
|
|
| def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| return "flash_attn2"
|
| if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
| return "flash_attn3"
|
| return None
|
|
|
|
|
| def _try_get_kernels_flash():
|
| try:
|
| from kernels import get_kernel
|
| except ImportError:
|
| return None, None
|
|
|
| flash_kernel = None
|
| flash_kernel_variant = None
|
| try:
|
| flash_kernel = get_kernel("kernels-community/flash-attn3")
|
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
|
| assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
|
| except Exception:
|
| try:
|
| flash_kernel = get_kernel("kernels-community/flash-attn2")
|
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
|
| assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
|
| except Exception:
|
| flash_kernel = None
|
| flash_kernel_variant = None
|
| return flash_kernel, flash_kernel_variant
|
|
|
|
|
# Lazy-load guard for the hub flash kernels (set by _ensure_flash_kernels_loaded).
_FLASH_KERNELS_LOADED = False
# Loaded kernel module (or None when unavailable).
FLASH_KERNEL = None
# Detected API variant: "flash_attn2", "flash_attn3", or None.
FLASH_KERNEL_VARIANT = None
|
|
|
|
|
def _ensure_flash_kernels_loaded():
    """Populate FLASH_KERNEL / FLASH_KERNEL_VARIANT exactly once (idempotent)."""
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        # Flip the flag before attempting the load so failures are not retried.
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
|
|
|
|
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Dense (non-varlen) flash attention via the loaded hub kernel.

    Dispatches on the detected kernel API variant; raises AssertionError when
    no kernel is loaded or the variant is unrecognized.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        # flash-attn2's fwd returns a tuple whose first element is the output.
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if variant == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
        except TypeError:
            # Some kernel builds only accept the positional calling convention.
            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
        return output[0] if isinstance(output, tuple) else output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
|
|
|
|
|
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length (unpadded) flash attention via the loaded hub kernel.

    q/k/v are packed token tensors whose per-sequence boundaries are given by
    the cumulative lengths in cu_seqlens_q / cu_seqlens_k.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        # flash-attn2 returns a tuple; the attention output is element 0.
        result = FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )
        return result[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states, k=key_states, v=value_states,
                cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            # Fall back to the positional calling convention used by some builds.
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states, key_states, value_states,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_in_batch_q, max_seqlen_in_batch_k,
                0.0, None, causal,
            )
        return output[0] if isinstance(output, tuple) else output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
|
|
|
|
|
|
|
class IndexFirstAxis(torch.autograd.Function):
    """Gather rows of an (N, ...) tensor along dim 0 by index, with a
    scatter-based backward that routes gradients to the selected rows."""

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        # Flatten trailing dims so a single 2-D gather selects whole rows.
        flat = input.reshape(input.shape[0], -1)
        row_index = indices.unsqueeze(1).expand(-1, flat.shape[1])
        return torch.gather(flat, 0, row_index).reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = grad_output.reshape(grad_output.shape[0], -1)
        # Unselected rows receive zero gradient.
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_grad.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, flat_grad.shape[1]), flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
|
|
|
|
|
class IndexPutFirstAxis(torch.autograd.Function):
    """Inverse of IndexFirstAxis: scatter value rows into a zero tensor of
    length ``first_axis_dim`` at the given dim-0 indices."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        out_shape = (first_axis_dim, *values.shape[1:])
        output = torch.zeros(out_shape, device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        # Gradient w.r.t. values is the slice of rows that were written;
        # indices and first_axis_dim are non-differentiable.
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None
|
|
|
|
|
# Functional aliases for the custom autograd ops above.
index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply
|
|
|
|
|
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed token rows back into a padded (batch, seqlen, ...) tensor.

    Positions not referenced by ``indices`` stay zero (the padding slots).
    """
    flat_padded = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return flat_padded.reshape(batch, seqlen, *flat_padded.shape[1:])
|
|
|
|
|
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Drop padding tokens from (B, L, H, D) q/k/v for varlen flash attention.

    Returns packed q/k/v of shape (total_tokens, H, D), the flat indices of
    kept tokens, cumulative sequence lengths, and the max sequence length.
    q and k share the same mask here, so the cu_seqlens/max_seqlen pairs are
    duplicated for both sides.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    lengths = attention_mask_2d.sum(dim=1).int()
    # Prepend a 0 so cu_seqlens[i]:cu_seqlens[i+1] brackets sequence i.
    cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(lengths.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q = index_first_axis(query_layer.reshape(flat_shape), indices)
    packed_k = index_first_axis(key_layer.reshape(flat_shape), indices)
    packed_v = index_first_axis(value_layer.reshape(flat_shape), indices)
    return packed_q, packed_k, packed_v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
|
|
|
|
|
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    """Flash attention over (B, L, H, D) q/k/v, honoring an optional padding mask.

    With a padding mask (non-causal only), tokens are unpadded, run through the
    varlen kernel, and scattered back into place; otherwise the dense kernel
    handles the whole batch directly.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    use_varlen = (attention_mask_2d is not None) and not causal
    if not use_varlen:
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
        )

    batch_size, q_len = query_states.shape[:2]
    (
        query_states, key_states, value_states,
        indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
    ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    attn_output_unpad = _kernels_flash_varlen_forward(
        query_states=query_states, key_states=key_states, value_states=value_states,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
    )
    return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
|
|
|
|
|
|
|
class AttentionBackend(Enum):
    """Selectable attention implementations.

    AUTO resolves at runtime to the best available backend
    (kernels flash > flex > SDPA); the others request one explicitly.
    """

    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"
|
|
|
|
|
# Config-string values accepted for attn_backend.
VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)

# Ensures the resolved-backend message prints only once per process.
_BACKEND_CONFIRMED = False
|
|
|
|
|
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a config string to a concrete AttentionBackend.

    "auto" prefers the hub flash kernel, then flex, then SDPA, based on what
    actually loaded; explicit requests are validated against availability.
    Logs the resolution to stdout once per process.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    # Only the auto/kernels paths need the (lazy, potentially slow) kernel probe.
    needs_kernel_probe = requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value)
    if needs_kernel_probe:
        _ensure_flash_kernels_loaded()

    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
|
|
|
|
|
@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build padding-mask variants once so every encoder layer can reuse them.

    Returns (attention_mask_2d, attention_mask_4d, flex_block_mask); entries
    the chosen backend does not need are None. With no input mask, all three
    are None.
    """
    if attention_mask is None:
        return None, None, None

    mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # Flash kernels consume the raw 2D mask via the unpad/varlen path.
        return mask_2d, None, None

    if effective_backend == AttentionBackend.FLEX:
        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
        valid_lens = mask_2d.sum(dim=-1)

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            # Keep (q, kv) pairs only when both positions fall inside the
            # sequence's valid prefix (assumes right-padded sequences).
            return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])

        block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
        return mask_2d, None, block_mask

    # SDPA / manual path: broadcastable (B, 1, 1, L) boolean mask.
    mask_4d = mask_2d[:, None, None, :]
    return mask_2d, mask_4d, None
|
|
|
| """FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.
|
|
|
| Usage:
|
| from transformers import AutoModel
|
| model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda()
|
|
|
| # Basic folding
|
| result = model.fold_protein("MKTLLILAVVA...")
|
| print(result["plddt"], result["pdb_string"][:100])
|
|
|
| # Folding with TTT (test-time training improves structure prediction)
|
| result = model.fold_protein("MKTLLILAVVA...", ttt=True)
|
|
|
| Dependencies: torch, transformers, einops, peft (for LoRA TTT only)
|
| No dependency on: esm (fair-esm), proteinttt, openfold
|
| """
|
| import copy
|
| from dataclasses import dataclass, field
|
| from functools import wraps
|
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import functional as F
|
|
|
| from einops import rearrange
|
| from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
|
| from transformers.modeling_outputs import ModelOutput
|
| from transformers.models.esm.configuration_esm import EsmConfig
|
| from transformers.models.esm.modeling_esm import (
|
| EsmContactPredictionHead,
|
| EsmEmbeddings,
|
| EsmIntermediate,
|
| EsmLMHead,
|
| EsmOutput,
|
| EsmSelfOutput,
|
| RotaryEmbedding,
|
| )
|
| from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Output container for the FastESM encoder/backbone."""

    # Final hidden states, (batch, seq_len, hidden_size).
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states (layer inputs plus the final output) when
    # output_hidden_states=True; otherwise None.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention probabilities when output_attentions=True; otherwise None.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
class EsmSelfAttention(nn.Module):
    """ESM2-style self-attention with pluggable backends (kernels flash, flex,
    SDPA, or a manual eager path).

    Queries are pre-scaled by head_dim**-0.5 once in forward(), so the SDPA
    and flex backends are invoked with scale=1.0.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Softmax scaling factor, applied to the queries once in forward().
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.config = config
        self.attn_backend = resolve_attention_backend(config.attn_backend)
        self.position_embedding_type = position_embedding_type or config.position_embedding_type
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project to q/k/v, apply rotary embeddings, and run the selected backend.

        Returns (attn_output of shape (batch, seq, hidden), attn_weights);
        weights are None unless output_attentions=True.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
        # (B, L, hidden) -> (B, H, L, D)
        query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pre-scale queries so downstream backends can run with scale=1.0.
        query_BHLD = query_BHLD * self.scale

        if self.position_embedding_type == "rotary":
            query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)

        attn_output, attn_weights = self._attn(
            query_BHLD, key_BHLD, value_BHLD,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return attn_output, attn_weights

    def _attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Route to the manual path when weights are requested, otherwise to the
        configured fused backend."""
        if output_attentions:
            # Only the manual implementation materializes attention weights.
            return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)

        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
        elif self.attn_backend == AttentionBackend.FLEX:
            return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
        elif self.attn_backend == AttentionBackend.SDPA:
            return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Eager attention returning the probability matrix alongside the output.

        No extra scaling here: queries arrive pre-scaled from forward().
        """
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
        if attention_mask_4d is not None:
            # Boolean mask, True = keep; padded positions get -inf before softmax.
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout_prob > 0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return attn_output, attn_weights

    def _kernels_flash_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Hub flash-kernel path; the kernels expect (B, L, H, D) layout.

        NOTE(review): queries arrive pre-scaled, but the flash kernels are
        called without an explicit softmax_scale and may apply their own
        default head_dim**-0.5 on top — verify this path does not double-scale.
        """
        query_BLHD = query_BHLD.transpose(1, 2).contiguous()
        key_BLHD = key_BHLD.transpose(1, 2).contiguous()
        value_BLHD = value_BHLD.transpose(1, 2).contiguous()
        attn_output = kernels_flash_attention_func(
            query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
            attention_mask_2d=attention_mask_2d, causal=False,
        )
        return rearrange(attn_output, "b s h d -> b s (h d)"), None

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: Optional[BlockMask] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Flex-attention path; scale=1.0 because queries are pre-scaled."""
        assert flex_attention is not None, "Flex attention is not available in this environment."
        fn = _get_flex_attention_fn()
        context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """torch scaled_dot_product_attention path; scale=1.0 because queries
        are pre-scaled. Dropout only applies in training mode."""
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=attention_mask_4d,
            dropout_p=self.dropout_prob if self.training else 0.0,
            scale=1.0,
        )
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
|
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention block: LN -> self-attention -> output projection
    that receives the original (un-normalized) input as its residual branch."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Pre-LN: attend over the normalized input, then hand the raw input to
        # the output module as the residual.
        normed_states = self.LayerNorm(hidden_states)
        attn_output, attn_weights = self.self(
            normed_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self.output(attn_output, hidden_states), attn_weights
|
|
|
|
|
class EsmLayer(nn.Module):
    """One transformer encoder layer: pre-LN attention followed by a pre-LN
    feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        attn_out, attn_weights = self.attention(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self._feed_forward(attn_out), attn_weights

    def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Pre-LN MLP: normalize, expand via intermediate, then project back with
        the un-normalized attention output as the residual input."""
        normed = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed)
        return self.output(expanded, attention_output)
|
|
|
|
|
class FastEsmEncoder(nn.Module):
    """Stack of EsmLayers; masks are built once up front and shared by every layer."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attention_backend = resolve_attention_backend(config.attn_backend)
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> FastEsmEncoderOutput:
        """Run every encoder layer and the final LayerNorm.

        Args:
            hidden_states: embedding output, (batch, seq_len, hidden_size).
            attention_mask: optional (batch, seq_len) padding mask (truthy = keep).
            output_hidden_states: collect each layer's input plus the final output.
            output_attentions: collect per-layer attention weights (forces the
                manual attention path inside EsmSelfAttention).

        Returns:
            FastEsmEncoderOutput with last_hidden_state and optional tuples.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Build the backend-specific mask variants once; all layers reuse them.
        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
            device=hidden_states.device,
            attention_mask=attention_mask,
        )

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states, attn_weights = layer_module(
                hidden_states,
                attention_mask_2d=attention_mask_2d,
                attention_mask_4d=attention_mask_4d,
                flex_block_mask=flex_block_mask,
                output_attentions=output_attentions,
            )

            if all_attentions is not None:
                all_attentions = all_attentions + (attn_weights,)

        # Fix: explicit None check instead of `if self.emb_layer_norm_after:` —
        # truthiness of an nn.Module is not a presence test and reads as a bug.
        # (The module is always created in __init__, so behavior is unchanged.)
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return FastEsmEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastEsmBackbone(nn.Module):
    """FastESM2 backbone with multi-backend attention. Drop-in replacement for
    transformers.EsmModel inside EsmForProteinFolding.

    State dict keys match HuggingFace EsmModel exactly, so pretrained weights
    load without any key remapping.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = FastEsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastEsmEncoderOutput:
        """Embed tokens and run the encoder.

        `return_dict` and extra kwargs are accepted for EsmModel signature
        compatibility but ignored; a FastEsmEncoderOutput is always returned.
        """
        if output_attentions is None:
            output_attentions = False
        if output_hidden_states is None:
            output_hidden_states = False

        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            embedded,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        return FastEsmEncoderOutput(
            last_hidden_state=encoded.last_hidden_state,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
| _ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
|
|
|
|
|
class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Wraps an existing nn.Linear and computes base(x) + lora_up(lora_down(x)) * scale.
    Initialization follows cloneofsimo/lora: down ~ Normal(0, 1/r), up = zeros,
    so a freshly wrapped layer is numerically identical to the original.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.linear(x)
        lora_out = self.lora_up(self.lora_down(x))
        return base_out + lora_out * self.scale
|
|
|
|
|
def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Replace nn.Linear layers inside modules matching target_class_name with LoRA.

    Mirrors lora_diffusion's inject_trainable_lora: every module whose class
    name equals ``target_class_name`` has its direct nn.Linear children swapped
    for LoraInjectedLinear wrappers, moved to the original weight's
    device/dtype. Returns the list of trainable LoRA parameters.
    """
    lora_params: List[nn.Parameter] = []
    for _parent_name, parent_module in model.named_modules():
        if parent_module.__class__.__name__ != target_class_name:
            continue
        # Snapshot children before mutating the parent.
        linear_children = [
            (name, child)
            for name, child in parent_module.named_children()
            if isinstance(child, nn.Linear)
        ]
        for child_name, child_module in linear_children:
            wrapped = LoraInjectedLinear(child_module, r=r, scale=scale).to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, wrapped)
            lora_params.extend(wrapped.lora_down.parameters())
            lora_params.extend(wrapped.lora_up.parameters())
    return lora_params
|
|
|
|
|
@dataclass
class TTTConfig:
    """Hyperparameters for test-time training (TTT) of the backbone."""

    # Learning rate for the TTT optimizer.
    lr: float = 4e-4
    # NOTE(review): presumably gradient-accumulation steps — confirm against the TTT loop.
    ags: int = 4
    # Number of TTT optimization steps (0 disables updates).
    steps: int = 10
    # Sequences per TTT batch.
    batch_size: int = 4
    # Fraction of tokens selected for the MLM objective, in (0, 1].
    mask_ratio: float = 0.15
    # Sequence crop length for TTT batches (usage not in this chunk — verify).
    crop_size: int = 1024
    # BERT-style corruption probabilities (leave unchanged / replace randomly);
    # verify() enforces that their sum stays <= 1.
    bert_leave_prob: float = 0.1
    bert_replace_prob: float = 0.1
    # "sgd" or "adamw" (enforced by verify()).
    optimizer: str = "sgd"
    momentum: float = 0.0
    weight_decay: float = 0.0
    # RNG seed for reproducible TTT; None leaves generators unseeded.
    seed: Optional[int] = 0
    # Snapshot trainable state after TTT setup so it can be restored later.
    initial_state_reset: bool = True
    freeze_embeddings: bool = True
    # LoRA rank; 0 switches to training a separate linear LM head instead.
    lora_rank: int = 8
    # LoRA scaling factor, passed as `scale` to LoraInjectedLinear.
    lora_alpha: float = 32.0
    # Class name whose nn.Linear children receive LoRA adapters.
    lora_target_class: str = "EsmSelfAttention"

    def verify(self) -> None:
        """Validate field ranges; raises AssertionError on the first violation."""
        assert self.lr > 0.0, "TTT learning rate must be positive."
        assert self.ags > 0, "TTT ags must be positive."
        assert self.steps >= 0, "TTT steps must be non-negative."
        assert self.batch_size > 0, "TTT batch_size must be positive."
        assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
        assert self.crop_size > 0, "TTT crop_size must be positive."
        assert 0.0 <= self.bert_leave_prob <= 1.0
        assert 0.0 <= self.bert_replace_prob <= 1.0
        assert self.bert_leave_prob + self.bert_replace_prob <= 1.0
        assert self.optimizer in {"sgd", "adamw"}
        assert self.lora_rank >= 0
        assert self.lora_alpha > 0.0
|
|
|
|
|
def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator for nn.Module methods: snapshot train/eval mode, device, and
    per-parameter requires_grad before the call, and restore them afterwards —
    even when the wrapped function raises. Parameters created during the call
    (absent from the snapshot) end up with requires_grad=False."""

    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        was_training = self.training
        original_device = next(self.parameters()).device
        grad_flags = {}
        for name, parameter in self.named_parameters():
            grad_flags[name] = parameter.requires_grad
        try:
            return func(self, *args, **kwargs)
        finally:
            self.train(was_training)
            self.to(original_device)
            for name, parameter in self.named_parameters():
                # Unknown (newly created) parameters default to frozen.
                parameter.requires_grad = grad_flags.get(name, False)

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastEsmFoldConfig(EsmConfig):
    """EsmConfig extended with an attention-backend selector and TTT defaults."""

    model_type = "fast_esmfold"

    def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        default_ttt = {
            "lr": 4e-4,
            "steps": 10,
            "lora_rank": 8,
            "lora_alpha": 32.0,
        }
        # Any falsy value (None or empty dict) falls back to the defaults,
        # matching the original `or` semantics.
        self.ttt_config = ttt_config if ttt_config else default_ttt
|
|
|
|
|
|
|
|
|
|
|
|
|
| class FastEsmForProteinFolding(EsmForProteinFolding):
|
| """ESMFold with FastESM2 attention backends + built-in Test-Time Training.
|
|
|
| Inherits all folding logic (trunk, structure module, output_to_pdb, infer)
|
| from transformers.EsmForProteinFolding. Replaces the ESM2 backbone with
|
| FastESM2 for optimized attention and adds TTT for improved structure prediction.
|
|
|
| Key API:
|
| result = model.fold_protein("MKTL...", ttt=True)
|
| # result = {"plddt": float, "ptm": float, "pdb_string": str}
|
| """
|
| config_class = FastEsmFoldConfig
|
|
|
    def __init__(self, config: FastEsmFoldConfig):
        """Build the folding model, swap in the FastESM backbone, and prepare TTT state.

        LoRA injection and the rest of TTT setup are deferred to
        _ensure_ttt_ready(), which must run after pretrained weights load.
        """
        super().__init__(config)

        # Escape hatch: keep the stock transformers ESM backbone when requested.
        if not config.ttt_config.get("use_standard_backbone", False):
            self.esm = FastEsmBackbone(config)
            # Freeze the backbone; TTT only trains injected adapters (or a
            # separate head) — see _ensure_ttt_ready.
            self.esm.requires_grad_(False)
            if config.esmfold_config.fp16_esm:
                self.esm.half()

        # LM head on top of the backbone (frozen during LoRA TTT in _ensure_ttt_ready).
        self.mlm_head = EsmLMHead(config)

        # "use_standard_backbone" is a backbone switch, not a TTTConfig field.
        ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"}
        self._ttt_cfg = TTTConfig(**ttt_kwargs)
        self._ttt_cfg.verify()
        self._ttt_initialized = False
        self._ttt_initial_state = None
        # Dedicated generator so TTT randomness is reproducible per-seed.
        self._ttt_generator = torch.Generator()
        if self._ttt_cfg.seed is not None:
            self._ttt_generator.manual_seed(self._ttt_cfg.seed)
        self._non_special_tokens_cache = None
        self._ttt_tokenizer = None
|
|
|
| def _get_ttt_tokenizer(self) -> EsmTokenizer:
|
| if self._ttt_tokenizer is None:
|
| self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| return self._ttt_tokenizer
|
|
|
| def _ensure_ttt_ready(self) -> None:
|
| """Lazy TTT initialization. Injects LoRA adapters and saves initial state.
|
| Must be called after weights are loaded (not in __init__)."""
|
| if self._ttt_initialized:
|
| return
|
| self._ttt_initialized = True
|
|
|
| tokenizer = self._get_ttt_tokenizer()
|
| vocab = tokenizer.get_vocab()
|
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
|
|
|
| if self._ttt_cfg.lora_rank > 0:
|
| self.mlm_head.eval()
|
| for p in self.mlm_head.parameters():
|
| p.requires_grad = False
|
|
|
| if self._ttt_cfg.seed is not None:
|
| torch.manual_seed(self._ttt_cfg.seed)
|
| self._inject_lora()
|
| else:
|
|
|
| H = self.config.hidden_size
|
| V = self.config.vocab_size
|
| device = next(self.esm.parameters()).device
|
| self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device)
|
|
|
| if self._ttt_cfg.initial_state_reset:
|
| self._ttt_initial_state = self._ttt_get_state()
|
|
|
| @property
|
| def _uses_lora(self) -> bool:
|
| return self._ttt_cfg.lora_rank > 0
|
|
|
| def _inject_lora(self) -> None:
|
| """Inject LoRA adapters into ESM2 attention layers (matching lora_diffusion behavior)."""
|
| self._lora_params = inject_trainable_lora(
|
| self.esm,
|
| target_class_name=self._ttt_cfg.lora_target_class,
|
| r=self._ttt_cfg.lora_rank,
|
| scale=self._ttt_cfg.lora_alpha,
|
| )
|
| assert len(self._lora_params) > 0, (
|
| f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' "
|
| f"matches attention modules in the backbone."
|
| )
|
|
|
|
|
|
|
| def _get_lora_modules(self) -> List[LoraInjectedLinear]:
|
| """Find all LoraInjectedLinear modules in the backbone."""
|
| return [m for m in self.esm.modules() if isinstance(m, LoraInjectedLinear)]
|
|
|
| def _ttt_get_state(self) -> Dict[str, Any]:
|
| if self._uses_lora:
|
| lora_state = []
|
| for m in self._get_lora_modules():
|
| lora_state.append({
|
| "down": m.lora_down.weight.data.clone(),
|
| "up": m.lora_up.weight.data.clone(),
|
| })
|
| return {"_lora_state": lora_state}
|
| return {
|
| "esm": copy.deepcopy(self.esm),
|
| "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj),
|
| }
|
|
|
| def _ttt_set_state(self, state: Dict[str, Any]) -> None:
|
| if "_lora_state" in state:
|
| modules = self._get_lora_modules()
|
| assert len(modules) == len(state["_lora_state"])
|
| for m, saved in zip(modules, state["_lora_state"]):
|
| m.lora_down.weight.data.copy_(saved["down"])
|
| m.lora_up.weight.data.copy_(saved["up"])
|
| return
|
| if "esm" in state:
|
| self.esm = copy.deepcopy(state["esm"])
|
| if "_ttt_lm_proj" in state:
|
| self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"])
|
|
|
| def ttt_reset(self) -> None:
|
| """Reset model to pre-TTT state (restore initial LoRA or backbone weights)."""
|
| assert self._ttt_initial_state is not None, "TTT reset requires initial_state_reset=True."
|
| self._ttt_set_state(self._ttt_initial_state)
|
|
|
|
|
|
|
| def _ttt_tokenize(self, seq: str) -> torch.Tensor:
|
| tokenizer = self._get_ttt_tokenizer()
|
| out = tokenizer(
|
| seq,
|
| return_tensors="pt",
|
| add_special_tokens=self._uses_lora,
|
| padding=False,
|
| truncation=False,
|
| )
|
| return out["input_ids"]
|
|
|
| def _ttt_mask_token(self) -> int:
|
| return self._get_ttt_tokenizer().mask_token_id
|
|
|
| def _ttt_get_non_special_tokens(self) -> List[int]:
|
| if self._non_special_tokens_cache is not None:
|
| return self._non_special_tokens_cache
|
| tokenizer = self._get_ttt_tokenizer()
|
| vocab = tokenizer.get_vocab()
|
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
|
| return self._non_special_tokens_cache
|
|
|
| def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor:
|
| """Run ESM2 backbone + LM head to get MLM logits."""
|
|
|
| output = self.esm(input_ids=batch)
|
| hidden = output.last_hidden_state
|
| if self._uses_lora:
|
| return self.mlm_head(hidden)
|
| return self._ttt_lm_proj(hidden)
|
|
|
    def _ttt_sample_batch(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build one BERT-style masked batch from a single tokenized sequence.

        Draws `batch_size` random crops of length `crop_size` from `x`, picks
        `mask_ratio` of the standard-amino-acid positions per crop (at least
        one), and applies BERT corruption: mask token with probability
        1 - leave - replace, random amino acid with probability `replace`,
        otherwise left unchanged. All randomness comes from the dedicated
        `self._ttt_generator` so sampling is reproducible under a fixed seed.

        Args:
            x: Token ids — assumed shape [1, seq_len] (single sequence).

        Returns:
            (batch_masked, batch_cropped, mask, start_indices):
            corrupted inputs, uncorrupted targets, boolean supervision mask
            (all [batch_size, crop_size]), and each row's crop start offset.
        """
        _, seq_len = x.shape
        batch_size = self._ttt_cfg.batch_size
        crop_size = min(self._ttt_cfg.crop_size, seq_len)

        # Broadcast the single sequence to batch_size rows (a view, no copy).
        x_expanded = x.expand(batch_size, -1)
        if seq_len == crop_size:
            start_indices = torch.zeros(batch_size, dtype=torch.long)
        else:
            start_indices = torch.randint(
                0, seq_len - crop_size + 1, (batch_size,),
                generator=self._ttt_generator,
            ).to(torch.long)

        batch_cropped = torch.stack([
            x_expanded[index, start : start + crop_size]
            for index, start in enumerate(start_indices)
        ])

        non_special_tokens = set(self._ttt_get_non_special_tokens())
        mask = torch.zeros((batch_size, crop_size), dtype=torch.bool)
        mask_token_id = self._ttt_mask_token()

        # Choose which positions to supervise: only standard amino acids,
        # at least one per row.
        for row_index in range(batch_size):
            non_special_positions = [
                col for col in range(crop_size)
                if batch_cropped[row_index, col].item() in non_special_tokens
            ]
            assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token."
            num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio)))
            sampled_indices = torch.randperm(
                len(non_special_positions), generator=self._ttt_generator,
            )[:num_to_mask]
            positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices]
            mask[row_index, positions_to_mask] = True

        # Apply BERT corruption to the selected positions. One uniform draw
        # per position decides: [0, 1-leave-replace) → mask token,
        # [1-leave-replace, 1-leave) → random amino acid, else → unchanged.
        batch_masked = batch_cropped.clone()
        for row_index in range(batch_size):
            masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0]
            for masked_position in masked_positions:
                probability = float(torch.rand(1, generator=self._ttt_generator).item())
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob:
                    batch_masked[row_index, masked_position] = mask_token_id
                    continue
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob:
                    replacement_candidates = self._ttt_get_non_special_tokens()
                    replacement_index = int(torch.randint(
                        0, len(replacement_candidates), (1,), generator=self._ttt_generator,
                    ).item())
                    batch_masked[row_index, masked_position] = replacement_candidates[replacement_index]

        return batch_masked, batch_cropped, mask, start_indices
|
|
|
| def _ttt_cross_entropy_loss(
|
| self,
|
| logits: torch.Tensor,
|
| targets: torch.Tensor,
|
| mask: torch.Tensor,
|
| ) -> torch.Tensor:
|
| assert logits.ndim == 3, "Logits must be [batch, seq, vocab]."
|
| _, _, vocab_size = logits.shape
|
| logits_flat = logits.reshape(-1, vocab_size)
|
| targets_flat = targets.reshape(-1)
|
| mask_flat = mask.reshape(-1)
|
| assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token."
|
| loss = F.cross_entropy(
|
| logits_flat[mask_flat],
|
| targets_flat[mask_flat],
|
| reduction="none",
|
| )
|
| masked_tokens_per_seq = mask.sum(dim=1).tolist()
|
| per_sequence_losses = torch.split(loss, masked_tokens_per_seq)
|
| return torch.stack([sl.mean() for sl in per_sequence_losses]).mean()
|
|
|
| def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer:
|
| if self._ttt_cfg.optimizer == "sgd":
|
| return torch.optim.SGD(
|
| parameters,
|
| lr=self._ttt_cfg.lr,
|
| momentum=self._ttt_cfg.momentum,
|
| weight_decay=self._ttt_cfg.weight_decay,
|
| )
|
| return torch.optim.AdamW(
|
| parameters,
|
| lr=self._ttt_cfg.lr,
|
| weight_decay=self._ttt_cfg.weight_decay,
|
| )
|
|
|
    def _lora_ttt(self, seq: str) -> Dict[str, List[float]]:
        """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen.

        Runs steps * ags micro-batches of masked-LM training, stepping the
        optimizer every `ags` micro-batches (gradient accumulation).

        Args:
            seq: Protein sequence (single-letter amino acid codes).

        Returns:
            Dict with "losses": per-micro-batch MLM loss values.
        """
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Freeze everything, then re-enable just the LoRA adapter weights.
        for parameter in self.parameters():
            parameter.requires_grad = False
        for p in self._lora_params:
            p.requires_grad = True
        optimizer = self._ttt_get_optimizer(self._lora_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            # train() before the forward pass (enables dropout etc. during TTT).
            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            # Optimizer step once every `ags` accumulated micro-batches.
            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}
|
|
|
    def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]:
        """Legacy TTT: full fine-tuning of ESM2 backbone with random linear projection head.

        Same loop as _lora_ttt, but trains the whole backbone (optionally with
        frozen embeddings) plus the `_ttt_lm_proj` head instead of LoRA adapters.

        Args:
            seq: Protein sequence (single-letter amino acid codes).

        Returns:
            Dict with "losses": per-micro-batch MLM loss values.
        """
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Freeze everything, then re-enable backbone + projection head
        # (optionally keeping the embedding table frozen).
        for parameter in self.parameters():
            parameter.requires_grad = False
        for parameter in self.esm.parameters():
            parameter.requires_grad = True
        if self._ttt_cfg.freeze_embeddings:
            for parameter in self.esm.embeddings.parameters():
                parameter.requires_grad = False
        for parameter in self._ttt_lm_proj.parameters():
            parameter.requires_grad = True

        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = self._ttt_get_optimizer(trainable_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            # train() before the forward pass (enables dropout etc. during TTT).
            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            # Optimizer step once every `ags` accumulated micro-batches.
            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}
|
|
|
| @preserve_model_state
|
| def ttt(self, seq: str) -> Dict[str, List[float]]:
|
| """Run test-time training on a single sequence using masked language modeling.
|
|
|
| Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence
|
| before structure prediction. Call fold_protein(seq, ttt=True) for the full pipeline.
|
|
|
| Args:
|
| seq: Protein sequence (single-letter amino acid codes)
|
|
|
| Returns:
|
| Dict with "losses" key containing per-step MLM loss values
|
| """
|
| self._ensure_ttt_ready()
|
|
|
|
|
| esm_dtype = next(self.esm.parameters()).dtype
|
| if esm_dtype != torch.float32:
|
| self.esm.float()
|
| self.mlm_head.float()
|
| if self._uses_lora:
|
| result = self._lora_ttt(seq)
|
| else:
|
| result = self._legacy_ttt(seq)
|
|
|
| if esm_dtype != torch.float32:
|
| self.esm.to(esm_dtype)
|
| self.mlm_head.to(esm_dtype)
|
| return result
|
|
|
|
|
|
|
| def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]:
|
| """Fold a sequence once and return pLDDT, ptm, and optionally PDB string."""
|
| with torch.no_grad():
|
| output = self.infer(sequence)
|
| plddt = output["plddt"]
|
|
|
|
|
| if plddt.dim() == 3:
|
| mean_plddt = float(plddt[:, :, 1].mean().item())
|
| elif plddt.dim() == 2:
|
| mean_plddt = float(plddt[:, 1].mean().item())
|
| else:
|
| mean_plddt = float(plddt.mean().item())
|
| result = {
|
| "plddt": mean_plddt,
|
| "ptm": float(output["ptm"].item()) if "ptm" in output else None,
|
| }
|
| if return_pdb_string:
|
| pdb_strings = self.output_to_pdb(output)
|
| result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings
|
| return result
|
|
|
| def fold_protein(
|
| self,
|
| sequence: str,
|
| return_pdb_string: bool = True,
|
| ) -> Dict[str, Any]:
|
| """Fold a protein sequence with test-time training.
|
|
|
| Runs TTT (masked language model adaptation via LoRA) for the configured
|
| number of steps, folding after each optimizer step to track pLDDT. Returns
|
| the structure with the highest pLDDT across all steps (including baseline).
|
|
|
| Args:
|
| sequence: Protein sequence (single-letter amino acid codes)
|
| return_pdb_string: If True, include PDB string in output
|
|
|
| Returns:
|
| Dict with keys:
|
| - plddt: float, best mean pLDDT across all TTT steps
|
| - ptm: float, predicted TM-score from best step
|
| - pdb_string: str (if return_pdb_string=True), PDB from best step
|
| - step_plddts: list[float], pLDDT at each step [baseline, s1, ..., s10]
|
| - best_step: int, which step produced best structure (0=baseline)
|
| """
|
| self._ensure_ttt_ready()
|
|
|
|
|
| esm_dtype = next(self.esm.parameters()).dtype
|
| if esm_dtype != torch.float32:
|
| self.esm.float()
|
| self.mlm_head.float()
|
|
|
| device = next(self.parameters()).device
|
| non_blocking = device.type == "cuda"
|
|
|
|
|
| best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
|
| step_plddts = [best["plddt"]]
|
|
|
| if self._ttt_cfg.steps > 0:
|
|
|
| x = self._ttt_tokenize(sequence)
|
|
|
|
|
| for p in self.parameters():
|
| p.requires_grad = False
|
| if self._uses_lora:
|
| for p in self._lora_params:
|
| p.requires_grad = True
|
| optimizer = self._ttt_get_optimizer(self._lora_params)
|
| else:
|
| for p in self.esm.parameters():
|
| p.requires_grad = True
|
| if self._ttt_cfg.freeze_embeddings:
|
| for p in self.esm.embeddings.parameters():
|
| p.requires_grad = False
|
| for p in self._ttt_lm_proj.parameters():
|
| p.requires_grad = True
|
| trainable = [p for p in self.parameters() if p.requires_grad]
|
| optimizer = self._ttt_get_optimizer(trainable)
|
| optimizer.zero_grad(set_to_none=True)
|
|
|
| self.eval()
|
| for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
|
| batch_masked, targets, mask, _start = self._ttt_sample_batch(x)
|
| batch_masked = batch_masked.to(device, non_blocking=non_blocking)
|
| targets = targets.to(device, non_blocking=non_blocking)
|
| mask = mask.to(device, non_blocking=non_blocking)
|
|
|
| self.train()
|
| logits = self._ttt_predict_logits(batch_masked)
|
| loss = self._ttt_cross_entropy_loss(logits, targets, mask)
|
| loss.backward()
|
|
|
| if (step + 1) % self._ttt_cfg.ags == 0:
|
| optimizer.step()
|
| optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
| self.eval()
|
| current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
|
| step_plddts.append(current["plddt"])
|
| if current["plddt"] > best["plddt"]:
|
| best = current
|
|
|
| self.eval()
|
|
|
|
|
| for p in self.parameters():
|
| p.requires_grad = False
|
|
|
|
|
| self.ttt_reset()
|
|
|
|
|
| if esm_dtype != torch.float32:
|
| self.esm.to(esm_dtype)
|
| self.mlm_head.to(esm_dtype)
|
|
|
| return {
|
| "plddt": best["plddt"],
|
| "ptm": best["ptm"],
|
| "pdb_string": best.get("pdb_string"),
|
| "step_plddts": step_plddts,
|
| "best_step": step_plddts.index(max(step_plddts)),
|
| }
|
|
|