| """FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training. |
| |
| Usage: |
| from transformers import AutoModel |
| model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda() |
| |
| # Basic folding |
| result = model.fold_protein("MKTLLILAVVA...") |
| print(result["plddt"], result["pdb_string"][:100]) |
| |
| # Folding with TTT (test-time training improves structure prediction) |
| result = model.fold_protein("MKTLLILAVVA...", ttt=True) |
| |
| Dependencies: torch, transformers, einops, peft (for LoRA TTT only) |
| No dependency on: esm (fair-esm), proteinttt, openfold |
| """ |
| import copy |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from functools import wraps |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
| from einops import rearrange |
| from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.models.esm.configuration_esm import EsmConfig |
| from transformers.models.esm.modeling_esm import ( |
| EsmContactPredictionHead, |
| EsmEmbeddings, |
| EsmIntermediate, |
| EsmLMHead, |
| EsmOutput, |
| EsmSelfOutput, |
| RotaryEmbedding, |
| ) |
| from transformers.models.esm.modeling_esmfold import EsmForProteinFolding |
|
|
|
|
| |
| |
| |
|
|
| try: |
| from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask |
| except ImportError: |
| create_block_mask = None |
| flex_attention = None |
| BlockMask = None |
|
|
# Lazily-built cache of the torch.compile-wrapped flex_attention callable.
_compiled_flex_attention = None
|
|
|
|
def _get_flex_attention_fn():
    """Return the flex_attention callable, compiled once per process.

    Returns None when flex attention is unavailable; returns the raw
    (uncompiled) function when the module-level debug flag is set.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    flex_module = torch.nn.attention.flex_attention
    debug_no_compile = getattr(flex_module, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False)
    if debug_no_compile:
        # Debug escape hatch: skip torch.compile entirely.
        return flex_attention
    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(flex_attention)
    return _compiled_flex_attention
|
|
|
|
| def _infer_kernels_flash_variant(kernel) -> str | None: |
| if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): |
| return "flash_attn2" |
| if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): |
| return "flash_attn3" |
| return None |
|
|
|
|
| def _try_get_kernels_flash(): |
| try: |
| from kernels import get_kernel |
| except ImportError: |
| return None, None |
|
|
| flash_kernel = None |
| flash_kernel_variant = None |
| try: |
| flash_kernel = get_kernel("kernels-community/flash-attn3") |
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) |
| assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." |
| except Exception: |
| try: |
| flash_kernel = get_kernel("kernels-community/flash-attn2") |
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) |
| assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." |
| except Exception: |
| flash_kernel = None |
| flash_kernel_variant = None |
| return flash_kernel, flash_kernel_variant |
|
|
|
|
# Module-level state for the hub flash-attention backend, populated lazily
# by _ensure_flash_kernels_loaded.
_FLASH_KERNELS_LOADED = False  # guards the one-time kernel probe
FLASH_KERNEL = None  # loaded kernel module, or None when unavailable
FLASH_KERNEL_VARIANT = None  # "flash_attn2" | "flash_attn3" | None
|
|
|
|
def _ensure_flash_kernels_loaded():
    """Probe for hub flash-attention kernels once per process (idempotent)."""
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        # Flip the flag before probing so a failed probe is never retried.
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
|
|
|
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Dense (padded) flash-attention forward through the loaded hub kernel.

    Tensors arrive as (batch, seq, heads, head_dim) — see the transpose in
    EsmSelfAttention._kernels_flash_attn — and the output keeps that layout.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        # flash-attn2 returns a tuple whose first element is the output.
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if FLASH_KERNEL_VARIANT != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
    except TypeError:
        # Positional fallback for builds rejecting keyword args; 0.0/None are
        # presumably dropout_p and softmax_scale — confirm against the kernel build.
        result = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
    return result[0] if isinstance(result, tuple) else result
|
|
|
|
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length (unpadded) flash-attention forward via the hub kernel.

    Inputs are packed as (total_valid_tokens, heads, head_dim) with cumulative
    sequence boundaries in `cu_seqlens_*` (see _unpad_input).
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        # flash-attn2's varlen entry point returns a tuple; index 0 is the output.
        packed = FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )
        return packed[0]
    if FLASH_KERNEL_VARIANT != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            causal=causal,
        )
    except TypeError:
        # Positional fallback for builds rejecting keyword args; 0.0/None are
        # presumably dropout_p and softmax_scale — confirm against the kernel build.
        result = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, None, causal,
        )
    return result[0] if isinstance(result, tuple) else result
|
|
|
|
| |
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable `x[indices]` along the first axis via gather/scatter.

    Forward selects rows; backward scatters gradients back into a zero tensor
    of the original first-axis size.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        ctx.first_axis_dim = first_axis_dim
        # Flatten trailing dims so a single gather along dim 0 suffices.
        flat = input.reshape(first_axis_dim, -1)
        row_index = indices.unsqueeze(1).expand(-1, flat.shape[1])
        return torch.gather(flat, 0, row_index).reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = grad_output.reshape(grad_output.shape[0], -1)
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_grad.shape[1]], device=flat_grad.device, dtype=flat_grad.dtype
        )
        # Scatter each selected row's gradient back to its source position.
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, flat_grad.shape[1]), flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
|
|
|
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of `values` into a zero tensor at `indices` (axis 0)."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # new_zeros inherits device and dtype from `values`.
        padded = values.new_zeros(first_axis_dim, *values.shape[1:])
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
        # Gradient of an index-assign is a plain gather at the same indices.
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None
|
|
|
|
# Functional aliases for the custom autograd ops above.
index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply
|
|
|
|
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Inverse of unpadding: scatter packed rows back into a (batch, seqlen, ...) tensor.

    Rows of `hidden_states` land at flat positions `indices`; every other
    position is zero-filled.
    """
    flat = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return flat.reshape(batch, seqlen, *flat.shape[1:])
|
|
|
|
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    """Strip padding tokens from q/k/v for the varlen flash-attention kernel.

    Returns packed q/k/v of shape (total_valid_tokens, heads, head_dim), the
    flat indices of the kept tokens, and the (cu_seqlens, max_seqlen) pairs
    for queries and keys — identical here, since q and k share one mask.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    lengths = attention_mask_2d.sum(dim=1).int()
    # Cumulative boundaries with a leading 0, as the varlen API expects.
    cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    longest = int(lengths.max().item())
    keep = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    def _pack(layer: torch.Tensor) -> torch.Tensor:
        # Flatten (batch, seq) and keep only rows where the mask is set.
        return index_first_axis(layer.reshape(batch_size * seq_len, num_heads, head_dim), keep)

    return _pack(query_layer), _pack(key_layer), _pack(value_layer), keep, (cu_seqlens, cu_seqlens), (longest, longest)
|
|
|
|
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: torch.Tensor | None = None,
    causal: bool = False,
) -> torch.Tensor:
    """Run hub flash-attention on (batch, seq, heads, head_dim) tensors.

    When a padding mask is supplied (and attention is non-causal), the batch
    is unpadded, run through the varlen kernel, and re-padded; otherwise the
    dense kernel is used directly.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    use_varlen_path = (not causal) and attention_mask_2d is not None
    if not use_varlen_path:
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
        )
    batch_size, q_len = query_states.shape[:2]
    (
        packed_q, packed_k, packed_v,
        kept_indices, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
    ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    packed_out = _kernels_flash_varlen_forward(
        query_states=packed_q, key_states=packed_k, value_states=packed_v,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
    )
    return pad_input(packed_out, kept_indices, batch_size, q_len)
|
|
|
|
| |
| |
| |
|
|
class AttentionBackend(Enum):
    """Attention implementation choices; AUTO resolves to the best available."""

    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"  # flash-attn kernels from the HF `kernels` hub
    FLEX = "flex"  # torch.nn.attention.flex_attention
    SDPA = "sdpa"  # torch.nn.functional.scaled_dot_product_attention
|
|
|
|
# All accepted string values for the `attn_backend` config field.
VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)

# Ensures the backend-resolution message is printed at most once per process.
_BACKEND_CONFIRMED = False
|
|
|
|
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a config string to a concrete AttentionBackend.

    "auto" prefers kernels-flash, then flex, then SDPA; explicit requests are
    validated against what this environment provides. The resolution is
    printed once per process.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    needs_kernel_probe = requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value)
    if needs_kernel_probe:
        # Kernel probing is lazy; only pay for it when it could matter.
        _ensure_flash_kernels_loaded()

    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        else:
            resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA
    else:
        # Membership was asserted above, so enum construction cannot fail.
        resolved = AttentionBackend(requested_backend)
        if resolved is AttentionBackend.KERNELS_FLASH:
            assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        elif resolved is AttentionBackend.FLEX:
            assert flex_attention is not None, "Flex Attention is not available in this environment."

    if not _BACKEND_CONFIRMED:
        # Announce the choice only once to avoid log spam across layers.
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
|
|
|
|
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
    """Prepare the mask representation each attention backend consumes.

    Returns (2-D bool mask, 4-D broadcastable bool mask, flex BlockMask);
    entries not needed by `effective_backend` are None. With no input mask
    everything is None (fully dense attention).
    """
    if attention_mask is None:
        return None, None, None

    mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The flash path unpads with the 2-D mask directly; nothing else needed.
        return mask_2d, None, None

    if effective_backend == AttentionBackend.FLEX:
        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
        lengths = mask_2d.sum(dim=-1)

        def _both_valid(batch_idx, head_idx, q_idx, kv_idx):
            # Assumes right-padding: valid tokens occupy positions [0, length).
            return (q_idx < lengths[batch_idx]) & (kv_idx < lengths[batch_idx])

        block_mask = create_block_mask(_both_valid, batch_size, 1, seq_len, seq_len, device=device)
        return mask_2d, None, block_mask

    # SDPA / manual attention: broadcastable (batch, 1, 1, seq) boolean mask.
    return mask_2d, mask_2d[:, None, None, :], None
|
|
|
|
| |
| |
| |
|
|
@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Output container for FastEsmEncoder / FastEsmBackbone."""

    # Final hidden states (post emb_layer_norm_after), shape (batch, seq, hidden).
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states; populated only when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention weights; populated only when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
|
|
|
| |
| |
| |
|
|
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with selectable backends (flash / flex / SDPA / manual).

    The query is pre-scaled by 1/sqrt(head_dim) in forward, so every backend
    runs with an effective scale of 1.0.

    NOTE(review): attention dropout is applied only on the manual and SDPA
    paths; the flash and flex paths never apply it — confirm this is intended
    if the backbone is ever trained with dropout enabled.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # 1/sqrt(head_dim); applied to the query once, up front.
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.config = config
        self.attn_backend = resolve_attention_backend(config.attn_backend)
        self.position_embedding_type = position_embedding_type or config.position_embedding_type
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Project q/k/v, optionally apply rotary embeddings, and dispatch to a backend.

        Returns (attn_output of shape (batch, seq, hidden), attn_weights) —
        weights are only materialized when output_attentions=True, else None.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        # BHLD = (batch, heads, length, head_dim) after the transpose below.
        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
        query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pre-scale the query so every backend runs with scale=1.0.
        query_BHLD = query_BHLD * self.scale

        if self.position_embedding_type == "rotary":
            query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)

        attn_output, attn_weights = self._attn(
            query_BHLD, key_BHLD, value_BHLD,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return attn_output, attn_weights

    def _attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Route to the right kernel; output_attentions forces the manual path,
        since optimized kernels never materialize attention weights."""
        if output_attentions:
            return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)

        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
        elif self.attn_backend == AttentionBackend.FLEX:
            return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
        elif self.attn_backend == AttentionBackend.SDPA:
            return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Explicit softmax attention; the only path that returns attention weights."""
        # No extra scaling here: the query was already scaled in forward().
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
        if attention_mask_4d is not None:
            # Boolean mask semantics: True = keep, so mask out the complement.
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout_prob > 0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return attn_output, attn_weights

    def _kernels_flash_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, None]:
        """Hub flash-attention path; kernels expect (batch, seq, heads, head_dim)."""
        query_BLHD = query_BHLD.transpose(1, 2).contiguous()
        key_BLHD = key_BHLD.transpose(1, 2).contiguous()
        value_BLHD = value_BHLD.transpose(1, 2).contiguous()
        attn_output = kernels_flash_attention_func(
            query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
            attention_mask_2d=attention_mask_2d, causal=False,
        )
        return rearrange(attn_output, "b s h d -> b s (h d)"), None

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: "BlockMask | None" = None,
    ) -> tuple[torch.Tensor, None]:
        """Flex-attention path; scale=1.0 because the query was pre-scaled."""
        assert flex_attention is not None, "Flex attention is not available in this environment."
        assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
            f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
        )
        fn = _get_flex_attention_fn()
        context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, None]:
        """SDPA fallback; scale=1.0 because the query was pre-scaled."""
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=attention_mask_4d,
            dropout_p=self.dropout_prob if self.training else 0.0,
            scale=1.0,
        )
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
class EsmAttention(nn.Module):
    """Pre-LN attention block: LayerNorm -> self-attention -> residual projection."""

    def __init__(self, config):
        super().__init__()
        # Submodule creation order is kept to preserve parameter registration
        # (state-dict keys) and the RNG stream consumed at init time.
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Normalize, attend, then merge back into the residual stream."""
        normed_input = self.LayerNorm(hidden_states)
        self_output, attn_weights = self.self(
            normed_input,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        # EsmSelfOutput combines the attention result with the pre-norm residual.
        return self.output(self_output, hidden_states), attn_weights
|
|
|
|
class EsmLayer(nn.Module):
    """One transformer block: pre-LN self-attention followed by a pre-LN MLP."""

    def __init__(self, config):
        super().__init__()
        # Submodule creation order is kept to preserve parameter registration
        # (state-dict keys) and the RNG stream consumed at init time.
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attention sub-block, then feed-forward sub-block."""
        attended, attn_weights = self.attention(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self._feed_forward(attended), attn_weights

    def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Pre-LN MLP; EsmOutput merges the expansion back with the residual."""
        normed = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed)
        return self.output(expanded, attention_output)
|
|
|
|
class FastEsmEncoder(nn.Module):
    """Stack of EsmLayer blocks sharing one backend-specific attention mask.

    The mask representation (2-D / 4-D / flex BlockMask) is built once per
    forward pass and reused by every layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attention_backend = resolve_attention_backend(config.attn_backend)
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> FastEsmEncoderOutput:
        """Run all layers over `hidden_states`.

        Args:
            hidden_states: (batch, seq, hidden) token embeddings.
            attention_mask: optional (batch, seq) padding mask (1 = keep).
            output_hidden_states: also collect hidden states before each layer
                and after the final LayerNorm.
            output_attentions: also collect per-layer attention weights (this
                forces the manual attention path inside each layer).

        Returns:
            FastEsmEncoderOutput with the final hidden state and, optionally,
            all hidden states / attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
            device=hidden_states.device,
            attention_mask=attention_mask,
        )

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states, attn_weights = layer_module(
                hidden_states,
                attention_mask_2d=attention_mask_2d,
                attention_mask_4d=attention_mask_4d,
                flex_block_mask=flex_block_mask,
                output_attentions=output_attentions,
            )

            if all_attentions is not None:
                all_attentions = all_attentions + (attn_weights,)

        # FIX: test module presence explicitly instead of by truthiness.
        # Relying on an nn.Module's truth value is a known pitfall (e.g. an
        # empty nn.ModuleList is falsy); `is not None` states the intent.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return FastEsmEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
| |
| |
| |
|
|
class FastEsmBackbone(nn.Module):
    """FastESM2 backbone with multi-backend attention. Drop-in replacement for
    transformers.EsmModel inside EsmForProteinFolding.

    State dict keys match HuggingFace EsmModel exactly, so pretrained weights
    load without any key remapping.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = FastEsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastEsmEncoderOutput:
        """Embed tokens and run the encoder; always returns FastEsmEncoderOutput.

        `return_dict` and extra kwargs are accepted for EsmModel signature
        compatibility but ignored.
        """
        # None means "off" for both optional outputs.
        if output_attentions is None:
            output_attentions = False
        if output_hidden_states is None:
            output_hidden_states = False

        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            embedded,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Re-wrap so callers always receive this module's output type.
        return FastEsmEncoderOutput(
            last_hidden_state=encoded.last_hidden_state,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
|
|
|
|
| |
| |
| |
|
|
# One-letter codes of the 20 standard amino acids; used to build the cache of
# non-special token ids for TTT (see _ensure_ttt_ready).
_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
|
|
|
|
class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Computes base(x) + lora_up(lora_down(x)) * scale around a wrapped
    nn.Linear. Initialization follows cloneofsimo/lora — down ~ N(0, 1/r),
    up = 0 — so the wrapper is an exact identity of the base layer at init.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        # Zero-initialized up-projection makes the adapter a no-op at start.
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.linear(x)
        adapter_out = self.lora_up(self.lora_down(x))
        return base_out + adapter_out * self.scale
|
|
|
|
def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Swap the nn.Linear children of every `target_class_name` module for LoRA.

    Mirrors lora_diffusion's inject_trainable_lora: each matching parent's
    direct nn.Linear children are replaced by LoraInjectedLinear wrappers,
    moved to the wrapped weight's device/dtype. Returns the new trainable
    LoRA parameters.
    """
    trainable: List[nn.Parameter] = []
    for _name, parent in model.named_modules():
        if parent.__class__.__name__ != target_class_name:
            continue
        # Snapshot children first — we mutate the parent while iterating.
        for attr_name, child in list(parent.named_children()):
            if not isinstance(child, nn.Linear):
                continue
            wrapped = LoraInjectedLinear(child, r=r, scale=scale).to(
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            setattr(parent, attr_name, wrapped)
            trainable.extend(wrapped.lora_down.parameters())
            trainable.extend(wrapped.lora_up.parameters())
    return trainable
|
|
|
|
@dataclass
class TTTConfig:
    """Hyperparameters for test-time training of the folding backbone."""

    lr: float = 4e-4  # optimizer learning rate
    ags: int = 4  # presumably gradient-accumulation steps — confirm in the TTT loop
    steps: int = 30  # number of TTT optimization steps (0 disables updates)
    batch_size: int = 4
    mask_ratio: float = 0.15  # fraction of tokens masked per TTT sample
    crop_size: int = 1024
    bert_leave_prob: float = 0.1  # BERT-style: leave masked token unchanged
    bert_replace_prob: float = 0.1  # BERT-style: replace with a random token
    optimizer: str = "sgd"  # "sgd" or "adamw"
    momentum: float = 0.0
    weight_decay: float = 0.0
    seed: Optional[int] = 0  # None disables seeding
    initial_state_reset: bool = True  # snapshot weights so TTT can be undone
    freeze_embeddings: bool = True
    lora_rank: int = 8  # 0 trains a full projection head instead of LoRA
    lora_alpha: float = 32.0
    lora_target_class: str = "EsmSelfAttention"

    def verify(self) -> None:
        """Validate hyperparameter ranges.

        Raises AssertionError explicitly (rather than via bare `assert`) so
        validation still runs under `python -O`; the exception type is kept
        so existing callers that catch AssertionError are unaffected.
        """
        if not self.lr > 0.0:
            raise AssertionError("TTT learning rate must be positive.")
        if not self.ags > 0:
            raise AssertionError("TTT ags must be positive.")
        if not self.steps >= 0:
            raise AssertionError("TTT steps must be non-negative.")
        if not self.batch_size > 0:
            raise AssertionError("TTT batch_size must be positive.")
        if not (0.0 < self.mask_ratio <= 1.0):
            raise AssertionError("TTT mask_ratio must be in (0, 1].")
        if not self.crop_size > 0:
            raise AssertionError("TTT crop_size must be positive.")
        if not (0.0 <= self.bert_leave_prob <= 1.0):
            raise AssertionError("TTT bert_leave_prob must be in [0, 1].")
        if not (0.0 <= self.bert_replace_prob <= 1.0):
            raise AssertionError("TTT bert_replace_prob must be in [0, 1].")
        if not (self.bert_leave_prob + self.bert_replace_prob <= 1.0):
            raise AssertionError("TTT bert_leave_prob + bert_replace_prob must not exceed 1.")
        if self.optimizer not in {"sgd", "adamw"}:
            raise AssertionError("TTT optimizer must be one of {'sgd', 'adamw'}.")
        if not self.lora_rank >= 0:
            raise AssertionError("TTT lora_rank must be non-negative.")
        if not self.lora_alpha > 0.0:
            raise AssertionError("TTT lora_alpha must be positive.")
|
|
|
|
def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator for model methods: restore train/eval mode, device, and
    per-parameter requires_grad flags after `func` returns (or raises)."""
    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        prior_training = self.training
        prior_device = next(self.parameters()).device
        prior_grad_flags = {
            name: parameter.requires_grad
            for name, parameter in self.named_parameters()
        }
        try:
            return func(self, *args, **kwargs)
        finally:
            self.train(prior_training)
            self.to(prior_device)
            for name, parameter in self.named_parameters():
                # Parameters created during func (absent from the snapshot)
                # default to frozen.
                parameter.requires_grad = prior_grad_flags.get(name, False)
    return wrapper
|
|
|
|
| |
| |
| |
|
|
class FastEsmFoldConfig(EsmConfig):
    """EsmConfig extended with an attention-backend choice and TTT settings."""

    model_type = "fast_esmfold"

    def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        if ttt_config:
            self.ttt_config = ttt_config
        else:
            # Falsy input (None or an empty dict) falls back to the stock recipe.
            self.ttt_config = {
                "lr": 4e-4,
                "steps": 30,
                "lora_rank": 8,
                "lora_alpha": 32.0,
            }
|
|
|
|
| |
| |
| |
|
|
| class FastEsmForProteinFolding(EsmForProteinFolding): |
| """ESMFold with FastESM2 attention backends + built-in Test-Time Training. |
| |
| Inherits all folding logic (trunk, structure module, output_to_pdb, infer) |
| from transformers.EsmForProteinFolding. Replaces the ESM2 backbone with |
| FastESM2 for optimized attention and adds TTT for improved structure prediction. |
| |
| Key API: |
| result = model.fold_protein("MKTL...", ttt=True) |
| # result = {"plddt": float, "ptm": float, "pdb_string": str} |
| """ |
| config_class = FastEsmFoldConfig |
|
|
| def __init__(self, config: FastEsmFoldConfig): |
| super().__init__(config) |
|
|
| |
| |
| if not config.ttt_config.get("use_standard_backbone", False): |
| self.esm = FastEsmBackbone(config) |
| self.esm.requires_grad_(False) |
| if config.esmfold_config.fp16_esm: |
| self.esm.half() |
|
|
| |
| self.mlm_head = EsmLMHead(config) |
|
|
| |
| ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"} |
| self._ttt_cfg = TTTConfig(**ttt_kwargs) |
| self._ttt_cfg.verify() |
| self._ttt_initialized = False |
| self._ttt_initial_state = None |
| self._ttt_generator = torch.Generator() |
| if self._ttt_cfg.seed is not None: |
| self._ttt_generator.manual_seed(self._ttt_cfg.seed) |
| self._non_special_tokens_cache = None |
| self._ttt_tokenizer = None |
|
|
| def _get_ttt_tokenizer(self) -> EsmTokenizer: |
| if self._ttt_tokenizer is None: |
| self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") |
| return self._ttt_tokenizer |
|
|
| def _ensure_ttt_ready(self) -> None: |
| """Lazy TTT initialization. Injects LoRA adapters and saves initial state. |
| Must be called after weights are loaded (not in __init__).""" |
| if self._ttt_initialized: |
| return |
| self._ttt_initialized = True |
|
|
| tokenizer = self._get_ttt_tokenizer() |
| vocab = tokenizer.get_vocab() |
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] |
|
|
| if self._ttt_cfg.lora_rank > 0: |
| self.mlm_head.eval() |
| for p in self.mlm_head.parameters(): |
| p.requires_grad = False |
| |
| if self._ttt_cfg.seed is not None: |
| torch.manual_seed(self._ttt_cfg.seed) |
| self._inject_lora() |
| else: |
| |
| H = self.config.hidden_size |
| V = self.config.vocab_size |
| device = next(self.esm.parameters()).device |
| self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device) |
|
|
| if self._ttt_cfg.initial_state_reset: |
| self._ttt_initial_state = self._ttt_get_state() |
|
|
    @property
    def _uses_lora(self) -> bool:
        """True when TTT adapts LoRA modules; False means full-backbone mode."""
        return self._ttt_cfg.lora_rank > 0
|
|
| def _inject_lora(self) -> None: |
| """Inject LoRA adapters into ESM2 attention layers (matching lora_diffusion behavior).""" |
| self._lora_params = inject_trainable_lora( |
| self.esm, |
| target_class_name=self._ttt_cfg.lora_target_class, |
| r=self._ttt_cfg.lora_rank, |
| scale=self._ttt_cfg.lora_alpha, |
| ) |
| assert len(self._lora_params) > 0, ( |
| f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' " |
| f"matches attention modules in the backbone." |
| ) |
|
|
| |
|
|
| def _get_lora_modules(self) -> List[LoraInjectedLinear]: |
| """Find all LoraInjectedLinear modules in the backbone.""" |
| return [m for m in self.esm.modules() if isinstance(m, LoraInjectedLinear)] |
|
|
| def _ttt_get_state(self) -> Dict[str, Any]: |
| if self._uses_lora: |
| lora_state = [] |
| for m in self._get_lora_modules(): |
| lora_state.append({ |
| "down": m.lora_down.weight.data.clone(), |
| "up": m.lora_up.weight.data.clone(), |
| }) |
| return {"_lora_state": lora_state} |
| return { |
| "esm": copy.deepcopy(self.esm), |
| "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj), |
| } |
|
|
| def _ttt_set_state(self, state: Dict[str, Any]) -> None: |
| if "_lora_state" in state: |
| modules = self._get_lora_modules() |
| assert len(modules) == len(state["_lora_state"]) |
| for m, saved in zip(modules, state["_lora_state"]): |
| m.lora_down.weight.data.copy_(saved["down"]) |
| m.lora_up.weight.data.copy_(saved["up"]) |
| return |
| if "esm" in state: |
| self.esm = copy.deepcopy(state["esm"]) |
| if "_ttt_lm_proj" in state: |
| self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"]) |
|
|
| def ttt_reset(self) -> None: |
| """Reset model to pre-TTT state (restore initial LoRA or backbone weights).""" |
| assert self._ttt_initial_state is not None, "TTT reset requires initial_state_reset=True." |
| self._ttt_set_state(self._ttt_initial_state) |
|
|
| |
|
|
| def _ttt_tokenize(self, seq: str) -> torch.Tensor: |
| tokenizer = self._get_ttt_tokenizer() |
| out = tokenizer( |
| seq, |
| return_tensors="pt", |
| add_special_tokens=self._uses_lora, |
| padding=False, |
| truncation=False, |
| ) |
| return out["input_ids"] |
|
|
| def _ttt_mask_token(self) -> int: |
| return self._get_ttt_tokenizer().mask_token_id |
|
|
| def _ttt_get_non_special_tokens(self) -> List[int]: |
| if self._non_special_tokens_cache is not None: |
| return self._non_special_tokens_cache |
| tokenizer = self._get_ttt_tokenizer() |
| vocab = tokenizer.get_vocab() |
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] |
| return self._non_special_tokens_cache |
|
|
| def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor: |
| """Run ESM2 backbone + LM head to get MLM logits.""" |
| |
| output = self.esm(input_ids=batch) |
| hidden = output.last_hidden_state |
| if self._uses_lora: |
| return self.mlm_head(hidden) |
| return self._ttt_lm_proj(hidden) |
|
|
    def _ttt_sample_batch(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build one BERT-style MLM batch from a single tokenized sequence.

        Expands the [1, seq_len] input to ``batch_size`` rows, takes an
        independent random crop of ``crop_size`` per row, then masks roughly
        ``mask_ratio`` of the non-special positions in each row (at least one).
        Masked positions get the mask token, a random standard amino acid, or
        are left unchanged, per the configured BERT corruption probabilities.

        All randomness comes from ``self._ttt_generator`` in a fixed call order
        (crops, then per-row mask sampling, then per-position corruption), so
        this order is part of the reproducibility contract — do not reorder.

        Args:
            x: Token ids of shape [1, seq_len] from _ttt_tokenize.

        Returns:
            Tuple (batch_masked, batch_cropped, mask, start_indices):
                batch_masked: [batch_size, crop_size] corrupted model inputs.
                batch_cropped: [batch_size, crop_size] uncorrupted targets.
                mask: [batch_size, crop_size] bool, True where loss is scored.
                start_indices: [batch_size] crop start offsets.
        """
        _, seq_len = x.shape
        batch_size = self._ttt_cfg.batch_size
        crop_size = min(self._ttt_cfg.crop_size, seq_len)

        # One independent random crop per batch row.
        x_expanded = x.expand(batch_size, -1)
        if seq_len == crop_size:
            start_indices = torch.zeros(batch_size, dtype=torch.long)
        else:
            start_indices = torch.randint(
                0, seq_len - crop_size + 1, (batch_size,),
                generator=self._ttt_generator,
            ).to(torch.long)

        batch_cropped = torch.stack([
            x_expanded[index, start : start + crop_size]
            for index, start in enumerate(start_indices)
        ])

        non_special_tokens = set(self._ttt_get_non_special_tokens())
        mask = torch.zeros((batch_size, crop_size), dtype=torch.bool)
        mask_token_id = self._ttt_mask_token()

        # Choose which positions to score: only standard amino-acid tokens,
        # at least one per row even for tiny mask ratios.
        for row_index in range(batch_size):
            non_special_positions = [
                col for col in range(crop_size)
                if batch_cropped[row_index, col].item() in non_special_tokens
            ]
            assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token."
            num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio)))
            sampled_indices = torch.randperm(
                len(non_special_positions), generator=self._ttt_generator,
            )[:num_to_mask]
            positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices]
            mask[row_index, positions_to_mask] = True

        # Corrupt the chosen positions: with prob (1 - leave - replace) use the
        # mask token, with prob `replace` substitute a random amino acid, and
        # with prob `leave` keep the original token (still scored in the loss).
        batch_masked = batch_cropped.clone()
        for row_index in range(batch_size):
            masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0]
            for masked_position in masked_positions:
                probability = float(torch.rand(1, generator=self._ttt_generator).item())
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob:
                    batch_masked[row_index, masked_position] = mask_token_id
                    continue
                if probability < 1.0 - self._ttt_cfg.bert_leave_prob:
                    replacement_candidates = self._ttt_get_non_special_tokens()
                    replacement_index = int(torch.randint(
                        0, len(replacement_candidates), (1,), generator=self._ttt_generator,
                    ).item())
                    batch_masked[row_index, masked_position] = replacement_candidates[replacement_index]

        return batch_masked, batch_cropped, mask, start_indices
|
|
| def _ttt_cross_entropy_loss( |
| self, |
| logits: torch.Tensor, |
| targets: torch.Tensor, |
| mask: torch.Tensor, |
| ) -> torch.Tensor: |
| assert logits.ndim == 3, "Logits must be [batch, seq, vocab]." |
| _, _, vocab_size = logits.shape |
| logits_flat = logits.reshape(-1, vocab_size) |
| targets_flat = targets.reshape(-1) |
| mask_flat = mask.reshape(-1) |
| assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token." |
| loss = F.cross_entropy( |
| logits_flat[mask_flat], |
| targets_flat[mask_flat], |
| reduction="none", |
| ) |
| masked_tokens_per_seq = mask.sum(dim=1).tolist() |
| per_sequence_losses = torch.split(loss, masked_tokens_per_seq) |
| return torch.stack([sl.mean() for sl in per_sequence_losses]).mean() |
|
|
| def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer: |
| if self._ttt_cfg.optimizer == "sgd": |
| return torch.optim.SGD( |
| parameters, |
| lr=self._ttt_cfg.lr, |
| momentum=self._ttt_cfg.momentum, |
| weight_decay=self._ttt_cfg.weight_decay, |
| ) |
| return torch.optim.AdamW( |
| parameters, |
| lr=self._ttt_cfg.lr, |
| weight_decay=self._ttt_cfg.weight_decay, |
| ) |
|
|
    def _lora_ttt(self, seq: str) -> Dict[str, List[float]]:
        """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen.

        Runs ``steps * ags`` micro-steps of masked-LM training on random crops
        of ``seq``, stepping the optimizer every ``ags`` micro-steps (gradient
        accumulation).

        Args:
            seq: Protein sequence (single-letter amino acid codes).

        Returns:
            Dict with "losses": per-micro-step MLM loss values (floats).
        """
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        # steps == 0 disables TTT entirely.
        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Freeze everything, then re-enable gradients only on the LoRA adapters.
        for parameter in self.parameters():
            parameter.requires_grad = False
        for p in self._lora_params:
            p.requires_grad = True
        optimizer = self._ttt_get_optimizer(self._lora_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            # train() re-enables dropout for the MLM forward pass.
            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            # NOTE(review): gradients are summed, not averaged, over the ags
            # micro-steps, so the effective LR scales with ags — confirm this
            # matches the reference TTT implementation.
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            # Optimizer step at each gradient-accumulation boundary.
            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}
|
|
    def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]:
        """Legacy TTT: full fine-tuning of ESM2 backbone with random linear projection head.

        Same training loop as _lora_ttt, but the trainable set is the whole ESM2
        backbone (optionally minus embeddings) plus the ``_ttt_lm_proj`` head.

        Args:
            seq: Protein sequence (single-letter amino acid codes).

        Returns:
            Dict with "losses": per-micro-step MLM loss values (floats).
        """
        x = self._ttt_tokenize(seq)
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        losses = []

        # steps == 0 disables TTT entirely.
        if self._ttt_cfg.steps == 0:
            return {"losses": losses}

        # Freeze everything, then re-enable gradients on the backbone (minus
        # embeddings if configured) and the linear projection head.
        for parameter in self.parameters():
            parameter.requires_grad = False
        for parameter in self.esm.parameters():
            parameter.requires_grad = True
        if self._ttt_cfg.freeze_embeddings:
            for parameter in self.esm.embeddings.parameters():
                parameter.requires_grad = False
        for parameter in self._ttt_lm_proj.parameters():
            parameter.requires_grad = True

        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = self._ttt_get_optimizer(trainable_params)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            # train() re-enables dropout for the MLM forward pass.
            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            # NOTE(review): gradients are summed, not averaged, over the ags
            # micro-steps — effective LR scales with ags (same as _lora_ttt).
            loss.backward()
            losses.append(float(loss.detach().cpu().item()))

            # Optimizer step at each gradient-accumulation boundary.
            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        self.eval()
        return {"losses": losses}
|
|
| @preserve_model_state |
| def ttt(self, seq: str) -> Dict[str, List[float]]: |
| """Run test-time training on a single sequence using masked language modeling. |
| |
| Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence |
| before structure prediction. Call fold_protein(seq, ttt=True) for the full pipeline. |
| |
| Args: |
| seq: Protein sequence (single-letter amino acid codes) |
| |
| Returns: |
| Dict with "losses" key containing per-step MLM loss values |
| """ |
| self._ensure_ttt_ready() |
| |
| |
| esm_dtype = next(self.esm.parameters()).dtype |
| if esm_dtype != torch.float32: |
| self.esm.float() |
| self.mlm_head.float() |
| if self._uses_lora: |
| result = self._lora_ttt(seq) |
| else: |
| result = self._legacy_ttt(seq) |
| |
| if esm_dtype != torch.float32: |
| self.esm.to(esm_dtype) |
| self.mlm_head.to(esm_dtype) |
| return result |
|
|
| |
|
|
| def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]: |
| """Fold a sequence once and return pLDDT, ptm, and optionally PDB string.""" |
| with torch.no_grad(): |
| output = self.infer(sequence) |
| plddt = output["plddt"] |
| if plddt.dim() >= 2: |
| mean_plddt = float(plddt.mean(dim=-1).mean().item()) |
| else: |
| mean_plddt = float(plddt.mean().item()) |
| result = { |
| "plddt": mean_plddt, |
| "ptm": float(output["ptm"].item()) if "ptm" in output else None, |
| } |
| if return_pdb_string: |
| pdb_strings = self.output_to_pdb(output) |
| result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings |
| return result |
|
|
| def fold_protein( |
| self, |
| sequence: str, |
| return_pdb_string: bool = True, |
| ) -> Dict[str, Any]: |
| """Fold a protein sequence with test-time training. |
| |
| Runs TTT (masked language model adaptation via LoRA) for the configured |
| number of steps, folding after each optimizer step to track pLDDT. Returns |
| the structure with the highest pLDDT across all steps (including baseline). |
| |
| Args: |
| sequence: Protein sequence (single-letter amino acid codes) |
| return_pdb_string: If True, include PDB string in output |
| |
| Returns: |
| Dict with keys: |
| - plddt: float, best mean pLDDT across all TTT steps |
| - ptm: float, predicted TM-score from best step |
| - pdb_string: str (if return_pdb_string=True), PDB from best step |
| - step_plddts: list[float], pLDDT at each step [baseline, s1, ..., s10] |
| - best_step: int, which step produced best structure (0=baseline) |
| """ |
| self._ensure_ttt_ready() |
|
|
| |
| esm_dtype = next(self.esm.parameters()).dtype |
| if esm_dtype != torch.float32: |
| self.esm.float() |
| self.mlm_head.float() |
|
|
| device = next(self.parameters()).device |
| non_blocking = device.type == "cuda" |
|
|
| |
| best = self._fold_single(sequence, return_pdb_string=return_pdb_string) |
| step_plddts = [best["plddt"]] |
|
|
| if self._ttt_cfg.steps > 0: |
| |
| x = self._ttt_tokenize(sequence) |
|
|
| |
| for p in self.parameters(): |
| p.requires_grad = False |
| if self._uses_lora: |
| for p in self._lora_params: |
| p.requires_grad = True |
| optimizer = self._ttt_get_optimizer(self._lora_params) |
| else: |
| for p in self.esm.parameters(): |
| p.requires_grad = True |
| if self._ttt_cfg.freeze_embeddings: |
| for p in self.esm.embeddings.parameters(): |
| p.requires_grad = False |
| for p in self._ttt_lm_proj.parameters(): |
| p.requires_grad = True |
| trainable = [p for p in self.parameters() if p.requires_grad] |
| optimizer = self._ttt_get_optimizer(trainable) |
| optimizer.zero_grad(set_to_none=True) |
|
|
| self.eval() |
| for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): |
| batch_masked, targets, mask, _start = self._ttt_sample_batch(x) |
| batch_masked = batch_masked.to(device, non_blocking=non_blocking) |
| targets = targets.to(device, non_blocking=non_blocking) |
| mask = mask.to(device, non_blocking=non_blocking) |
|
|
| self.train() |
| logits = self._ttt_predict_logits(batch_masked) |
| loss = self._ttt_cross_entropy_loss(logits, targets, mask) |
| loss.backward() |
|
|
| if (step + 1) % self._ttt_cfg.ags == 0: |
| optimizer.step() |
| optimizer.zero_grad(set_to_none=True) |
|
|
| |
| self.eval() |
| current = self._fold_single(sequence, return_pdb_string=return_pdb_string) |
| step_plddts.append(current["plddt"]) |
| if current["plddt"] > best["plddt"]: |
| best = current |
|
|
| self.eval() |
|
|
| |
| for p in self.parameters(): |
| p.requires_grad = False |
|
|
| |
| self.ttt_reset() |
|
|
| |
| if esm_dtype != torch.float32: |
| self.esm.to(esm_dtype) |
| self.mlm_head.to(esm_dtype) |
|
|
| return { |
| "plddt": best["plddt"], |
| "ptm": best["ptm"], |
| "pdb_string": best.get("pdb_string"), |
| "step_plddts": step_plddts, |
| "best_step": step_plddts.index(max(step_plddts)), |
| } |
|
|