lhallee committed on
Commit ea6dd9b · verified · 1 Parent(s): b84f03a

Upload modeling_fast_esmfold.py with huggingface_hub

Files changed (1)
  1. modeling_fast_esmfold.py +1135 -0
modeling_fast_esmfold.py ADDED
@@ -0,0 +1,1135 @@
1
+ """FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.
2
+
3
+ Usage:
4
+ from transformers import AutoModel
5
+ model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda()
6
+
7
+ # Basic folding
8
+ result = model.fold_protein("MKTLLILAVVA...")
9
+ print(result["plddt"], result["pdb_string"][:100])
10
+
11
+ # Folding with TTT (test-time training improves structure prediction)
12
+ result = model.fold_protein("MKTLLILAVVA...", ttt=True)
13
+
14
+ Dependencies: torch, transformers, einops, peft (for LoRA TTT only)
15
+ No dependency on: esm (fair-esm), proteinttt, openfold
16
+ """
17
+ import copy
18
+ from dataclasses import dataclass, field
19
+ from enum import Enum
20
+ from functools import wraps
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn import functional as F
26
+
27
+ from einops import rearrange
28
+ from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
29
+ from transformers.modeling_outputs import ModelOutput
30
+ from transformers.models.esm.configuration_esm import EsmConfig
31
+ from transformers.models.esm.modeling_esm import (
32
+ EsmContactPredictionHead,
33
+ EsmEmbeddings,
34
+ EsmIntermediate,
35
+ EsmLMHead,
36
+ EsmOutput,
37
+ EsmSelfOutput,
38
+ RotaryEmbedding,
39
+ )
40
+ from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
41
+
42
+
43
+ # =============================================================================
44
+ # Flash Attention Detection (from FastPLMs/esm2/modeling_fastesm.py)
45
+ # =============================================================================
46
+
47
+ try:
48
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
49
+ except ImportError:
50
+ create_block_mask = None
51
+ flex_attention = None
52
+ BlockMask = None
53
+
54
+ _compiled_flex_attention = None
55
+
56
+
57
+ def _get_flex_attention_fn():
58
+ global _compiled_flex_attention
59
+ if flex_attention is None:
60
+ return None
61
+ flex_mod = torch.nn.attention.flex_attention
62
+ if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
63
+ return flex_attention
64
+ if _compiled_flex_attention is None:
65
+ _compiled_flex_attention = torch.compile(flex_attention)
66
+ return _compiled_flex_attention
67
+
68
+
69
+ def _infer_kernels_flash_variant(kernel) -> str | None:
70
+ if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
71
+ return "flash_attn2"
72
+ if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
73
+ return "flash_attn3"
74
+ return None
75
+
76
+
77
+ def _try_get_kernels_flash():
78
+ try:
79
+ from kernels import get_kernel
80
+ except ImportError:
81
+ return None, None
82
+
83
+ flash_kernel = None
84
+ flash_kernel_variant = None
85
+ try:
86
+ flash_kernel = get_kernel("kernels-community/flash-attn3")
87
+ flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
88
+ assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
89
+ except Exception:
90
+ try:
91
+ flash_kernel = get_kernel("kernels-community/flash-attn2")
92
+ flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
93
+ assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
94
+ except Exception:
95
+ flash_kernel = None
96
+ flash_kernel_variant = None
97
+ return flash_kernel, flash_kernel_variant
98
+
99
+
100
+ _FLASH_KERNELS_LOADED = False
101
+ FLASH_KERNEL = None
102
+ FLASH_KERNEL_VARIANT = None
103
+
104
+
105
+ def _ensure_flash_kernels_loaded():
106
+ global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
107
+ if _FLASH_KERNELS_LOADED:
108
+ return
109
+ _FLASH_KERNELS_LOADED = True
110
+ FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
111
+
112
+
113
+ def _kernels_flash_forward(
114
+ query_states: torch.Tensor,
115
+ key_states: torch.Tensor,
116
+ value_states: torch.Tensor,
117
+ causal: bool = False,
118
+ ) -> torch.Tensor:
119
+ assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
120
+ if FLASH_KERNEL_VARIANT == "flash_attn2":
121
+ return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
122
+ if FLASH_KERNEL_VARIANT == "flash_attn3":
123
+ try:
124
+ output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
125
+ except TypeError:
126
+ output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
127
+ if isinstance(output, tuple):
128
+ return output[0]
129
+ return output
130
+ raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
131
+
132
+
133
+ def _kernels_flash_varlen_forward(
134
+ query_states: torch.Tensor,
135
+ key_states: torch.Tensor,
136
+ value_states: torch.Tensor,
137
+ cu_seqlens_q: torch.Tensor,
138
+ cu_seqlens_k: torch.Tensor,
139
+ max_seqlen_in_batch_q: int,
140
+ max_seqlen_in_batch_k: int,
141
+ causal: bool = False,
142
+ ) -> torch.Tensor:
143
+ assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
144
+ if FLASH_KERNEL_VARIANT == "flash_attn2":
145
+ return FLASH_KERNEL.varlen_fwd(
146
+ q=query_states, k=key_states, v=value_states,
147
+ cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
148
+ max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
149
+ is_causal=causal,
150
+ )[0]
151
+ if FLASH_KERNEL_VARIANT == "flash_attn3":
152
+ try:
153
+ output = FLASH_KERNEL.flash_attn_varlen_func(
154
+ q=query_states, k=key_states, v=value_states,
155
+ cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
156
+ max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
157
+ causal=causal,
158
+ )
159
+ except TypeError:
160
+ output = FLASH_KERNEL.flash_attn_varlen_func(
161
+ query_states, key_states, value_states,
162
+ cu_seqlens_q, cu_seqlens_k,
163
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k,
164
+ 0.0, None, causal,
165
+ )
166
+ if isinstance(output, tuple):
167
+ return output[0]
168
+ return output
169
+ raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
170
+
171
+
172
+ # Unpad / Pad helpers for varlen flash attention
173
+ class IndexFirstAxis(torch.autograd.Function):
174
+ @staticmethod
175
+ def forward(ctx, input, indices) -> torch.Tensor:
176
+ ctx.save_for_backward(indices)
177
+ assert input.ndim >= 2
178
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
179
+ second_dim = other_shape.numel()
180
+ return torch.gather(
181
+ rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim)
182
+ ).reshape(-1, *other_shape)
183
+
184
+ @staticmethod
185
+ def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
186
+ (indices,) = ctx.saved_tensors
187
+ assert grad_output.ndim >= 2
188
+ other_shape = grad_output.shape[1:]
189
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
190
+ grad_input = torch.zeros(
191
+ [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
192
+ )
193
+ grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
194
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
195
+
196
+
197
+ class IndexPutFirstAxis(torch.autograd.Function):
198
+ @staticmethod
199
+ def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
200
+ ctx.save_for_backward(indices)
201
+ assert indices.ndim == 1
202
+ assert values.ndim >= 2
203
+ output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
204
+ output[indices] = values
205
+ return output
206
+
207
+ @staticmethod
208
+ def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
209
+ (indices,) = ctx.saved_tensors
210
+ return grad_output[indices], None, None
211
+
212
+
213
+ index_first_axis = IndexFirstAxis.apply
214
+ index_put_first_axis = IndexPutFirstAxis.apply
215
+
216
+
217
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
218
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
219
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
220
+
221
+
222
+ def _unpad_input(
223
+ query_layer: torch.Tensor,
224
+ key_layer: torch.Tensor,
225
+ value_layer: torch.Tensor,
226
+ attention_mask_2d: torch.Tensor,
227
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
228
+ batch_size, seq_len, num_heads, head_dim = query_layer.shape
229
+ seqlens = attention_mask_2d.sum(dim=1).int()
230
+ cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
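+ # cu_seqlens is the cumulative-length vector [0, len_0, len_0 + len_1, ...] expected by varlen flash attention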
231
+ max_seqlen = int(seqlens.max().item())
232
+ indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
233
+ query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
234
+ key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
235
+ value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
236
+ return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
237
+
238
+
239
+ def kernels_flash_attention_func(
240
+ query_states: torch.Tensor,
241
+ key_states: torch.Tensor,
242
+ value_states: torch.Tensor,
243
+ attention_mask_2d: torch.Tensor | None = None,
244
+ causal: bool = False,
245
+ ) -> torch.Tensor:
246
+ assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
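+ # With a padding mask (and no causal masking), padded tokens are stripped, the packed sequences are
+ # run through the varlen kernel, and the outputs are scattered back to the padded [batch, seq] layout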
247
+ if not causal and attention_mask_2d is not None:
248
+ batch_size, q_len = query_states.shape[:2]
249
+ (
250
+ query_states, key_states, value_states,
251
+ indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
252
+ ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
253
+ attn_output_unpad = _kernels_flash_varlen_forward(
254
+ query_states=query_states, key_states=key_states, value_states=value_states,
255
+ cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
256
+ max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
257
+ )
258
+ return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
259
+ else:
260
+ return _kernels_flash_forward(
261
+ query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
262
+ )
263
+
264
+
265
+ # =============================================================================
266
+ # Attention Backend Enum & Resolution
267
+ # =============================================================================
268
+
269
+ class AttentionBackend(Enum):
270
+ AUTO = "auto"
271
+ KERNELS_FLASH = "kernels_flash"
272
+ FLEX = "flex"
273
+ SDPA = "sdpa"
274
+
275
+
276
+ VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)
277
+
278
+ _BACKEND_CONFIRMED = False
279
+
280
+
281
+ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
282
+ global _BACKEND_CONFIRMED
283
+ assert requested_backend in VALID_ATTENTION_BACKENDS, (
284
+ f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
285
+ )
286
+ if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
287
+ _ensure_flash_kernels_loaded()
288
+ if requested_backend == AttentionBackend.AUTO.value:
289
+ if FLASH_KERNEL is not None:
290
+ resolved = AttentionBackend.KERNELS_FLASH
291
+ elif flex_attention is not None:
292
+ resolved = AttentionBackend.FLEX
293
+ else:
294
+ resolved = AttentionBackend.SDPA
295
+ elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
296
+ assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
297
+ resolved = AttentionBackend.KERNELS_FLASH
298
+ elif requested_backend == AttentionBackend.FLEX.value:
299
+ assert flex_attention is not None, "Flex Attention is not available in this environment."
300
+ resolved = AttentionBackend.FLEX
301
+ elif requested_backend == AttentionBackend.SDPA.value:
302
+ resolved = AttentionBackend.SDPA
303
+ else:
304
+ raise AssertionError(f"Unsupported attention backend: {requested_backend}")
305
+ if not _BACKEND_CONFIRMED:
306
+ print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
307
+ _BACKEND_CONFIRMED = True
308
+ return resolved
309
+
310
+
311
+ def get_attention_mask(
312
+ effective_backend: AttentionBackend,
313
+ batch_size: int,
314
+ seq_len: int,
315
+ device: torch.device,
316
+ attention_mask: Optional[torch.Tensor] = None,
317
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
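+ # Returns (mask_2d, mask_4d, flex_block_mask): the flash path consumes the 2D mask for unpadding,
+ # flex attention gets a BlockMask, and SDPA / manual attention get a broadcastable 4D boolean mask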
318
+ if attention_mask is None:
319
+ return None, None, None
320
+
321
+ attention_mask_2d = attention_mask.bool()
322
+
323
+ if effective_backend == AttentionBackend.KERNELS_FLASH:
324
+ return attention_mask_2d, None, None
325
+
326
+ if effective_backend == AttentionBackend.FLEX:
327
+ assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
328
+ valid_lens = attention_mask_2d.sum(dim=-1)
329
+
330
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
331
+ return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
332
+
333
+ flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
334
+ return attention_mask_2d, None, flex_block_mask
335
+
336
+ # SDPA / manual
337
+ attention_mask_4d = attention_mask_2d[:, None, None, :]
338
+ return attention_mask_2d, attention_mask_4d, None
339
+
340
+
341
+ # =============================================================================
342
+ # Output Dataclass
343
+ # =============================================================================
344
+
345
+ @dataclass
346
+ class FastEsmEncoderOutput(ModelOutput):
347
+ last_hidden_state: Optional[torch.Tensor] = None
348
+ hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
349
+ attentions: Optional[Tuple[torch.Tensor, ...]] = None
350
+
351
+
352
+ # =============================================================================
353
+ # FastESM2 Attention Layers (multi-backend: SDPA, Flash, Flex)
354
+ # =============================================================================
355
+
356
+ class EsmSelfAttention(nn.Module):
357
+ def __init__(self, config, position_embedding_type: Optional[str] = None):
358
+ super().__init__()
359
+ assert config.hidden_size % config.num_attention_heads == 0, (
360
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
361
+ f"heads ({config.num_attention_heads})"
362
+ )
363
+ self.num_attention_heads = config.num_attention_heads
364
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
365
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
366
+
367
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
368
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
369
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
370
+ self.scale = self.attention_head_size**-0.5
371
+
372
+ self.dropout_prob = config.attention_probs_dropout_prob
373
+ self.config = config
374
+ self.attn_backend = resolve_attention_backend(config.attn_backend)
375
+ self.position_embedding_type = position_embedding_type or config.position_embedding_type
376
+ self.rotary_embeddings = None
377
+ if self.position_embedding_type == "rotary":
378
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
379
+
380
+ def forward(
381
+ self,
382
+ hidden_states: torch.Tensor,
383
+ attention_mask_2d: torch.Tensor | None = None,
384
+ attention_mask_4d: torch.Tensor | None = None,
385
+ flex_block_mask: "BlockMask | None" = None,
386
+ output_attentions: bool = False,
387
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
388
+ batch_size, seq_length = hidden_states.shape[:-1]
389
+ hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
390
+ query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
391
+ key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
392
+ value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
393
+
394
+ query_BHLD = query_BHLD * self.scale
395
+
396
+ if self.position_embedding_type == "rotary":
397
+ query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)
398
+
399
+ attn_output, attn_weights = self._attn(
400
+ query_BHLD, key_BHLD, value_BHLD,
401
+ attention_mask_2d=attention_mask_2d,
402
+ attention_mask_4d=attention_mask_4d,
403
+ flex_block_mask=flex_block_mask,
404
+ output_attentions=output_attentions,
405
+ )
406
+ return attn_output, attn_weights
407
+
408
+ def _attn(
409
+ self,
410
+ query_BHLD: torch.Tensor,
411
+ key_BHLD: torch.Tensor,
412
+ value_BHLD: torch.Tensor,
413
+ attention_mask_2d: torch.Tensor | None = None,
414
+ attention_mask_4d: torch.Tensor | None = None,
415
+ flex_block_mask: "BlockMask | None" = None,
416
+ output_attentions: bool = False,
417
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
418
+ if output_attentions:
419
+ return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
420
+
421
+ if self.attn_backend == AttentionBackend.KERNELS_FLASH:
422
+ return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
423
+ elif self.attn_backend == AttentionBackend.FLEX:
424
+ return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
425
+ elif self.attn_backend == AttentionBackend.SDPA:
426
+ return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
427
+ else:
428
+ raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")
429
+
430
+ def _manual_attn(
431
+ self,
432
+ query_BHLD: torch.Tensor,
433
+ key_BHLD: torch.Tensor,
434
+ value_BHLD: torch.Tensor,
435
+ attention_mask_4d: torch.Tensor | None = None,
436
+ ) -> tuple[torch.Tensor, torch.Tensor]:
437
+ attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
438
+ if attention_mask_4d is not None:
439
+ attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
440
+ attn_weights = F.softmax(attn_weights, dim=-1)
441
+ if self.dropout_prob > 0 and self.training:
442
+ attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
443
+ context_BHLD = torch.matmul(attn_weights, value_BHLD)
444
+ attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
445
+ return attn_output, attn_weights
446
+
447
+ def _kernels_flash_attn(
448
+ self,
449
+ query_BHLD: torch.Tensor,
450
+ key_BHLD: torch.Tensor,
451
+ value_BHLD: torch.Tensor,
452
+ attention_mask_2d: torch.Tensor | None = None,
453
+ ) -> tuple[torch.Tensor, None]:
454
+ query_BLHD = query_BHLD.transpose(1, 2).contiguous()
455
+ key_BLHD = key_BHLD.transpose(1, 2).contiguous()
456
+ value_BLHD = value_BHLD.transpose(1, 2).contiguous()
457
+ attn_output = kernels_flash_attention_func(
458
+ query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
459
+ attention_mask_2d=attention_mask_2d, causal=False,
460
+ )
461
+ return rearrange(attn_output, "b s h d -> b s (h d)"), None
462
+
463
+ def _flex_attn(
464
+ self,
465
+ query_BHLD: torch.Tensor,
466
+ key_BHLD: torch.Tensor,
467
+ value_BHLD: torch.Tensor,
468
+ flex_block_mask: "BlockMask | None" = None,
469
+ ) -> tuple[torch.Tensor, None]:
470
+ assert flex_attention is not None, "Flex attention is not available in this environment."
471
+ assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
472
+ f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
473
+ )
474
+ fn = _get_flex_attention_fn()
475
+ context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
476
+ return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
477
+
478
+ def _sdpa_attn(
479
+ self,
480
+ query_BHLD: torch.Tensor,
481
+ key_BHLD: torch.Tensor,
482
+ value_BHLD: torch.Tensor,
483
+ attention_mask_4d: torch.Tensor | None = None,
484
+ ) -> tuple[torch.Tensor, None]:
485
+ context_BHLD = F.scaled_dot_product_attention(
486
+ query_BHLD, key_BHLD, value_BHLD,
487
+ attn_mask=attention_mask_4d,
488
+ dropout_p=self.dropout_prob if self.training else 0.0,
489
+ scale=1.0,
490
+ )
491
+ return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
492
+
493
+
494
+ class EsmAttention(nn.Module):
495
+ def __init__(self, config):
496
+ super().__init__()
497
+ self.self = EsmSelfAttention(config)
498
+ self.output = EsmSelfOutput(config)
499
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask_2d: torch.Tensor | None = None,
505
+ attention_mask_4d: torch.Tensor | None = None,
506
+ flex_block_mask: "BlockMask | None" = None,
507
+ output_attentions: bool = False,
508
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
509
+ hidden_states_ln = self.LayerNorm(hidden_states)
510
+ attn_output, attn_weights = self.self(
511
+ hidden_states_ln,
512
+ attention_mask_2d=attention_mask_2d,
513
+ attention_mask_4d=attention_mask_4d,
514
+ flex_block_mask=flex_block_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self.output(attn_output, hidden_states)
518
+ return attention_output, attn_weights
519
+
520
+
521
+ class EsmLayer(nn.Module):
522
+ def __init__(self, config):
523
+ super().__init__()
524
+ self.attention = EsmAttention(config)
525
+ self.intermediate = EsmIntermediate(config)
526
+ self.output = EsmOutput(config)
527
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states: torch.Tensor,
532
+ attention_mask_2d: torch.Tensor | None = None,
533
+ attention_mask_4d: torch.Tensor | None = None,
534
+ flex_block_mask: "BlockMask | None" = None,
535
+ output_attentions: bool = False,
536
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
537
+ attention_output, attn_weights = self.attention(
538
+ hidden_states,
539
+ attention_mask_2d=attention_mask_2d,
540
+ attention_mask_4d=attention_mask_4d,
541
+ flex_block_mask=flex_block_mask,
542
+ output_attentions=output_attentions,
543
+ )
544
+ layer_output = self._feed_forward(attention_output)
545
+ return layer_output, attn_weights
546
+
547
+ def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
548
+ attention_output_ln = self.LayerNorm(attention_output)
549
+ intermediate_output = self.intermediate(attention_output_ln)
550
+ return self.output(intermediate_output, attention_output)
551
+
552
+
553
+ class FastEsmEncoder(nn.Module):
554
+ def __init__(self, config):
555
+ super().__init__()
556
+ self.config = config
557
+ self.attention_backend = resolve_attention_backend(config.attn_backend)
558
+ self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
559
+ self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
560
+
561
+ def forward(
562
+ self,
563
+ hidden_states: torch.Tensor,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ output_hidden_states: bool = False,
566
+ output_attentions: bool = False,
567
+ ) -> FastEsmEncoderOutput:
568
+ all_hidden_states = () if output_hidden_states else None
569
+ all_attentions = () if output_attentions else None
570
+
571
+ attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
572
+ effective_backend=self.attention_backend,
573
+ batch_size=hidden_states.shape[0],
574
+ seq_len=hidden_states.shape[1],
575
+ device=hidden_states.device,
576
+ attention_mask=attention_mask,
577
+ )
578
+
579
+ for layer_module in self.layer:
580
+ if output_hidden_states:
581
+ all_hidden_states = all_hidden_states + (hidden_states,)
582
+
583
+ hidden_states, attn_weights = layer_module(
584
+ hidden_states,
585
+ attention_mask_2d=attention_mask_2d,
586
+ attention_mask_4d=attention_mask_4d,
587
+ flex_block_mask=flex_block_mask,
588
+ output_attentions=output_attentions,
589
+ )
590
+
591
+ if all_attentions is not None:
592
+ all_attentions = all_attentions + (attn_weights,)
593
+
594
+ if self.emb_layer_norm_after:
595
+ hidden_states = self.emb_layer_norm_after(hidden_states)
596
+
597
+ if output_hidden_states:
598
+ all_hidden_states = all_hidden_states + (hidden_states,)
599
+
600
+ return FastEsmEncoderOutput(
601
+ last_hidden_state=hidden_states,
602
+ hidden_states=all_hidden_states,
603
+ attentions=all_attentions,
604
+ )
605
+
606
+
607
+ # =============================================================================
608
+ # FastESM Backbone (replaces EsmModel inside ESMFold)
609
+ # =============================================================================
610
+
611
+ class FastEsmBackbone(nn.Module):
612
+ """FastESM2 backbone with multi-backend attention. Drop-in replacement for
613
+ transformers.EsmModel inside EsmForProteinFolding.
614
+
615
+ State dict keys match HuggingFace EsmModel exactly, so pretrained weights
616
+ load without any key remapping.
617
+ """
618
+
619
+ def __init__(self, config):
620
+ super().__init__()
621
+ self.config = config
622
+ self.embeddings = EsmEmbeddings(config)
623
+ self.encoder = FastEsmEncoder(config)
624
+ self.contact_head = EsmContactPredictionHead(
625
+ in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
626
+ )
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: Optional[torch.Tensor] = None,
631
+ attention_mask: Optional[torch.Tensor] = None,
632
+ position_ids: Optional[torch.Tensor] = None,
633
+ inputs_embeds: Optional[torch.Tensor] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ return_dict: Optional[bool] = None,
637
+ **kwargs,
638
+ ) -> FastEsmEncoderOutput:
639
+ output_attentions = output_attentions if output_attentions is not None else False
640
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
641
+
642
+ token_embedding_output = self.embeddings(
643
+ input_ids=input_ids,
644
+ position_ids=position_ids,
645
+ attention_mask=attention_mask,
646
+ inputs_embeds=inputs_embeds,
647
+ )
648
+ encoder_outputs = self.encoder(
649
+ token_embedding_output,
650
+ attention_mask=attention_mask,
651
+ output_hidden_states=output_hidden_states,
652
+ output_attentions=output_attentions,
653
+ )
654
+ return FastEsmEncoderOutput(
655
+ last_hidden_state=encoder_outputs.last_hidden_state,
656
+ hidden_states=encoder_outputs.hidden_states,
657
+ attentions=encoder_outputs.attentions,
658
+ )
659
+
660
+
661
+ # =============================================================================
662
+ # TTT (Test-Time Training) Configuration and Utilities
663
+ # =============================================================================
664
+
665
+ _ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
666
+
667
+
668
+ @dataclass
669
+ class TTTConfig:
670
+ lr: float = 4e-4
671
+ ags: int = 4
672
+ steps: int = 30
673
+ batch_size: int = 4
674
+ mask_ratio: float = 0.15
675
+ crop_size: int = 1024
676
+ bert_leave_prob: float = 0.1
677
+ bert_replace_prob: float = 0.1
678
+ optimizer: str = "sgd"
679
+ momentum: float = 0.0
680
+ weight_decay: float = 0.0
681
+ seed: Optional[int] = 0
682
+ initial_state_reset: bool = True
683
+ freeze_embeddings: bool = True
684
+ lora_rank: int = 8
685
+ lora_alpha: float = 32.0
686
+ lora_target_modules: Tuple[str, ...] = ("query", "key", "value")
687
+
688
+ def verify(self) -> None:
689
+ assert self.lr > 0.0, "TTT learning rate must be positive."
690
+ assert self.ags > 0, "TTT ags must be positive."
691
+ assert self.steps >= 0, "TTT steps must be non-negative."
692
+ assert self.batch_size > 0, "TTT batch_size must be positive."
693
+ assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
694
+ assert self.crop_size > 0, "TTT crop_size must be positive."
695
+ assert 0.0 <= self.bert_leave_prob <= 1.0
696
+ assert 0.0 <= self.bert_replace_prob <= 1.0
697
+ assert self.bert_leave_prob + self.bert_replace_prob <= 1.0
698
+ assert self.optimizer in {"sgd", "adamw"}
699
+ assert self.lora_rank >= 0
700
+ assert self.lora_alpha > 0.0
701
+
702
+
703
+ def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]:
704
+ @wraps(func)
705
+ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
706
+ was_training = self.training
707
+ original_device = next(self.parameters()).device
708
+ original_requires_grad = {
709
+ name: parameter.requires_grad
710
+ for name, parameter in self.named_parameters()
711
+ }
712
+ try:
713
+ return func(self, *args, **kwargs)
714
+ finally:
715
+ self.train(was_training)
716
+ self.to(original_device)
717
+ for name, parameter in self.named_parameters():
718
+ if name in original_requires_grad:
719
+ parameter.requires_grad = original_requires_grad[name]
720
+ else:
721
+ parameter.requires_grad = False
722
+ return wrapper
723
+
724
+
725
+ # =============================================================================
726
+ # FastEsmFoldConfig
727
+ # =============================================================================
728
+
729
+ class FastEsmFoldConfig(EsmConfig):
730
+ model_type = "fast_esmfold"
731
+
732
+ def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs):
733
+ super().__init__(**kwargs)
734
+ self.attn_backend = attn_backend
735
+ self.ttt_config = ttt_config or {
736
+ "lr": 4e-4,
737
+ "steps": 30,
738
+ "lora_rank": 8,
739
+ "lora_alpha": 32.0,
740
+ }
741
+
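+ # Illustrative load-time override (a sketch; assumes the standard transformers behaviour of forwarding
+ # unrecognised from_pretrained kwargs onto the config):
+ # model = AutoModel.from_pretrained(
+ #     "Synthyra/FastESMFold", trust_remote_code=True,
+ #     attn_backend="flex",
+ #     ttt_config={"lr": 4e-4, "steps": 30, "lora_rank": 8, "lora_alpha": 32.0},
+ # )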
742
+
743
+ # =============================================================================
744
+ # FastEsmForProteinFolding
745
+ # =============================================================================
746
+
747
+ class FastEsmForProteinFolding(EsmForProteinFolding):
748
+ """ESMFold with FastESM2 attention backends + built-in Test-Time Training.
749
+
750
+ Inherits all folding logic (trunk, structure module, output_to_pdb, infer)
751
+ from transformers.EsmForProteinFolding. Replaces the ESM2 backbone with
752
+ FastESM2 for optimized attention and adds TTT for improved structure prediction.
753
+
754
+ Key API:
755
+ result = model.fold_protein("MKTL...", ttt=True)
756
+ # result = {"plddt": float, "ptm": float, "pdb_string": str}
757
+ """
758
+ config_class = FastEsmFoldConfig
759
+
760
+ def __init__(self, config: FastEsmFoldConfig):
761
+ super().__init__(config)
762
+
763
+ # Replace standard ESM2 backbone with FastESM2 (multi-backend attention)
764
+ self.esm = FastEsmBackbone(config)
765
+ self.esm.requires_grad_(False)
766
+ if config.esmfold_config.fp16_esm:
767
+ self.esm.half()
768
+
769
+ # MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear)
770
+ self.mlm_head = EsmLMHead(config)
771
+
772
+ # TTT state (lazy initialization)
773
+ self._ttt_cfg = TTTConfig(**config.ttt_config)
774
+ self._ttt_cfg.verify()
775
+ self._ttt_initialized = False
776
+ self._ttt_initial_state = None
777
+ self._ttt_generator = torch.Generator()
778
+ if self._ttt_cfg.seed is not None:
779
+ self._ttt_generator.manual_seed(self._ttt_cfg.seed)
780
+ self._non_special_tokens_cache = None
781
+ self._ttt_tokenizer = None
782
+
783
+ def _get_ttt_tokenizer(self) -> EsmTokenizer:
784
+ if self._ttt_tokenizer is None:
785
+ self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
786
+ return self._ttt_tokenizer
787
+
788
+ def _ensure_ttt_ready(self) -> None:
789
+ """Lazy TTT initialization. Injects LoRA adapters and saves initial state.
790
+ Must be called after weights are loaded (not in __init__)."""
791
+ if self._ttt_initialized:
792
+ return
793
+ self._ttt_initialized = True
794
+
795
+ tokenizer = self._get_ttt_tokenizer()
796
+ vocab = tokenizer.get_vocab()
797
+ self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
798
+
799
+ if self._ttt_cfg.lora_rank > 0:
800
+ self.mlm_head.eval()
801
+ for p in self.mlm_head.parameters():
802
+ p.requires_grad = False
803
+ self._inject_lora()
804
+ else:
805
+ # Legacy path: randomly initialized linear projection head, trained jointly with the backbone
806
+ H = self.config.hidden_size
807
+ V = self.config.vocab_size
808
+ device = next(self.esm.parameters()).device
809
+ self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device)
810
+
811
+ if self._ttt_cfg.initial_state_reset:
812
+ self._ttt_initial_state = self._ttt_get_state()
813
+
814
+ @property
815
+ def _uses_lora(self) -> bool:
816
+ return self._ttt_cfg.lora_rank > 0
817
+
818
+ def _inject_lora(self) -> None:
819
+ from peft import LoraConfig, inject_adapter_in_model
820
+
821
+ lora_config = LoraConfig(
822
+ r=self._ttt_cfg.lora_rank,
823
+ lora_alpha=self._ttt_cfg.lora_alpha,
824
+ target_modules=list(self._ttt_cfg.lora_target_modules),
825
+ lora_dropout=0.0,
826
+ bias="none",
827
+ )
828
+ inject_adapter_in_model(lora_config, self.esm, adapter_name="ttt")
829
+
830
+ # ---- TTT State Management ----
831
+
832
+ def _ttt_get_state(self) -> Dict[str, Any]:
833
+ if self._uses_lora:
834
+ lora_state = {
835
+ k: v.clone() for k, v in self.esm.state_dict().items()
836
+ if "lora_" in k
837
+ }
838
+ return {"_lora_state": lora_state}
839
+ return {
840
+ "esm": copy.deepcopy(self.esm),
841
+ "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj),
842
+ }
843
+
844
+ def _ttt_set_state(self, state: Dict[str, Any]) -> None:
845
+ if "_lora_state" in state:
846
+ current_state = self.esm.state_dict()
847
+ current_state.update(state["_lora_state"])
848
+ self.esm.load_state_dict(current_state)
849
+ return
850
+ if "esm" in state:
851
+ self.esm = copy.deepcopy(state["esm"])
852
+ if "_ttt_lm_proj" in state:
853
+ self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"])
854
+
855
+ def ttt_reset(self) -> None:
856
+ """Reset model to pre-TTT state (restore initial LoRA or backbone weights)."""
857
+ assert self._ttt_initial_state is not None, "TTT reset requires initial_state_reset=True."
858
+ self._ttt_set_state(self._ttt_initial_state)
859
+
860
+ # ---- TTT Core ----
861
+
862
+ def _ttt_tokenize(self, seq: str) -> torch.Tensor:
863
+ tokenizer = self._get_ttt_tokenizer()
864
+ out = tokenizer(
865
+ seq,
866
+ return_tensors="pt",
867
+ add_special_tokens=self._uses_lora,
868
+ padding=False,
869
+ truncation=False,
870
+ )
871
+ return out["input_ids"]
872
+
873
+ def _ttt_mask_token(self) -> int:
874
+ return self._get_ttt_tokenizer().mask_token_id
875
+
876
+ def _ttt_get_non_special_tokens(self) -> List[int]:
877
+ if self._non_special_tokens_cache is not None:
878
+ return self._non_special_tokens_cache
879
+ tokenizer = self._get_ttt_tokenizer()
880
+ vocab = tokenizer.get_vocab()
881
+ self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab]
882
+ return self._non_special_tokens_cache
883
+
884
+ def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor:
885
+ """Run ESM2 backbone + LM head to get MLM logits."""
886
+ # Forward pass through the backbone; which parameters receive gradients is configured by the TTT caller
887
+ output = self.esm(input_ids=batch)
888
+ hidden = output.last_hidden_state
889
+ if self._uses_lora:
890
+ return self.mlm_head(hidden)
891
+ return self._ttt_lm_proj(hidden)
892
+
893
+ def _ttt_sample_batch(
894
+ self,
895
+ x: torch.Tensor,
896
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
897
+ _, seq_len = x.shape
898
+ batch_size = self._ttt_cfg.batch_size
899
+ crop_size = min(self._ttt_cfg.crop_size, seq_len)
900
+
901
+ x_expanded = x.expand(batch_size, -1)
902
+ if seq_len == crop_size:
903
+ start_indices = torch.zeros(batch_size, dtype=torch.long)
904
+ else:
905
+ start_indices = torch.randint(
906
+ 0, seq_len - crop_size + 1, (batch_size,),
907
+ generator=self._ttt_generator,
908
+ ).to(torch.long)
909
+
910
+ batch_cropped = torch.stack([
911
+ x_expanded[index, start : start + crop_size]
912
+ for index, start in enumerate(start_indices)
913
+ ])
914
+
915
+ non_special_tokens = set(self._ttt_get_non_special_tokens())
916
+ mask = torch.zeros((batch_size, crop_size), dtype=torch.bool)
917
+ mask_token_id = self._ttt_mask_token()
918
+
919
+ for row_index in range(batch_size):
920
+ non_special_positions = [
921
+ col for col in range(crop_size)
922
+ if batch_cropped[row_index, col].item() in non_special_tokens
923
+ ]
924
+ assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token."
925
+ num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio)))
926
+ sampled_indices = torch.randperm(
927
+ len(non_special_positions), generator=self._ttt_generator,
928
+ )[:num_to_mask]
929
+ positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices]
930
+ mask[row_index, positions_to_mask] = True
931
+
932
+ batch_masked = batch_cropped.clone()
933
+ for row_index in range(batch_size):
934
+ masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0]
935
+ for masked_position in masked_positions:
936
+ probability = float(torch.rand(1, generator=self._ttt_generator).item())
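+ # BERT-style corruption: with prob (1 - bert_leave_prob - bert_replace_prob) substitute <mask>,
+ # with prob bert_replace_prob substitute a random standard amino acid, otherwise leave the token as-is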
937
+ if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob:
938
+ batch_masked[row_index, masked_position] = mask_token_id
939
+ continue
940
+ if probability < 1.0 - self._ttt_cfg.bert_leave_prob:
941
+ replacement_candidates = self._ttt_get_non_special_tokens()
942
+ replacement_index = int(torch.randint(
943
+ 0, len(replacement_candidates), (1,), generator=self._ttt_generator,
944
+ ).item())
945
+ batch_masked[row_index, masked_position] = replacement_candidates[replacement_index]
946
+
947
+ return batch_masked, batch_cropped, mask, start_indices
948
+
949
+ def _ttt_cross_entropy_loss(
950
+ self,
951
+ logits: torch.Tensor,
952
+ targets: torch.Tensor,
953
+ mask: torch.Tensor,
954
+ ) -> torch.Tensor:
955
+ assert logits.ndim == 3, "Logits must be [batch, seq, vocab]."
956
+ _, _, vocab_size = logits.shape
957
+ logits_flat = logits.reshape(-1, vocab_size)
958
+ targets_flat = targets.reshape(-1)
959
+ mask_flat = mask.reshape(-1)
960
+ assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token."
961
+ loss = F.cross_entropy(
962
+ logits_flat[mask_flat],
963
+ targets_flat[mask_flat],
964
+ reduction="none",
965
+ )
966
+ masked_tokens_per_seq = mask.sum(dim=1).tolist()
967
+ per_sequence_losses = torch.split(loss, masked_tokens_per_seq)
968
+ return torch.stack([sl.mean() for sl in per_sequence_losses]).mean()
969
+
970
+ def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer:
971
+ if self._ttt_cfg.optimizer == "sgd":
972
+ return torch.optim.SGD(
973
+ parameters,
974
+ lr=self._ttt_cfg.lr,
975
+ momentum=self._ttt_cfg.momentum,
976
+ weight_decay=self._ttt_cfg.weight_decay,
977
+ )
978
+ return torch.optim.AdamW(
979
+ parameters,
980
+ lr=self._ttt_cfg.lr,
981
+ weight_decay=self._ttt_cfg.weight_decay,
982
+ )
983
+
984
+ def _lora_ttt(self, seq: str) -> Dict[str, List[float]]:
985
+ """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen."""
986
+ x = self._ttt_tokenize(seq)
987
+ device = next(self.parameters()).device
988
+ non_blocking = device.type == "cuda"
989
+ losses = []
990
+
991
+ if self._ttt_cfg.steps == 0:
992
+ return {"losses": losses}
993
+
994
+ for parameter in self.parameters():
995
+ parameter.requires_grad = False
996
+ for name, parameter in self.esm.named_parameters():
997
+ if "lora_" in name:
998
+ parameter.requires_grad = True
999
+ lora_params = [p for n, p in self.esm.named_parameters() if "lora_" in n]
1000
+ optimizer = self._ttt_get_optimizer(lora_params)
1001
+ optimizer.zero_grad(set_to_none=True)
1002
+
1003
+ self.eval()
1004
+ for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
1005
+ batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
1006
+ batch_masked = batch_masked.to(device, non_blocking=non_blocking)
1007
+ targets = targets.to(device, non_blocking=non_blocking)
1008
+ mask = mask.to(device, non_blocking=non_blocking)
1009
+
1010
+ self.train()
1011
+ logits = self._ttt_predict_logits(batch_masked)
1012
+ loss = self._ttt_cross_entropy_loss(logits, targets, mask)
1013
+ loss.backward()
1014
+ losses.append(float(loss.detach().cpu().item()))
1015
+
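+ # 'ags' counts gradient-accumulation steps: the optimizer updates once every ags backward passes,
+ # for a total of cfg.steps optimizer updates over the whole loop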
1016
+ if (step + 1) % self._ttt_cfg.ags == 0:
1017
+ optimizer.step()
1018
+ optimizer.zero_grad(set_to_none=True)
1019
+
1020
+ self.eval()
1021
+ return {"losses": losses}
1022
+
1023
+ def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]:
1024
+ """Legacy TTT: full fine-tuning of ESM2 backbone with random linear projection head."""
1025
+ x = self._ttt_tokenize(seq)
1026
+ device = next(self.parameters()).device
1027
+ non_blocking = device.type == "cuda"
1028
+ losses = []
1029
+
1030
+ if self._ttt_cfg.steps == 0:
1031
+ return {"losses": losses}
1032
+
1033
+ # Full fine-tune: all backbone params trainable
1034
+ for parameter in self.parameters():
1035
+ parameter.requires_grad = False
1036
+ for parameter in self.esm.parameters():
1037
+ parameter.requires_grad = True
1038
+ if self._ttt_cfg.freeze_embeddings:
1039
+ for parameter in self.esm.embeddings.parameters():
1040
+ parameter.requires_grad = False
1041
+ for parameter in self._ttt_lm_proj.parameters():
1042
+ parameter.requires_grad = True
1043
+
1044
+ trainable_params = filter(lambda p: p.requires_grad, self.parameters())
1045
+ optimizer = self._ttt_get_optimizer(trainable_params)
1046
+ optimizer.zero_grad(set_to_none=True)
1047
+
1048
+ self.eval()
1049
+ for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
1050
+ batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x)
1051
+ batch_masked = batch_masked.to(device, non_blocking=non_blocking)
1052
+ targets = targets.to(device, non_blocking=non_blocking)
1053
+ mask = mask.to(device, non_blocking=non_blocking)
1054
+
1055
+ self.train()
1056
+ logits = self._ttt_predict_logits(batch_masked)
1057
+ loss = self._ttt_cross_entropy_loss(logits, targets, mask)
1058
+ loss.backward()
1059
+ losses.append(float(loss.detach().cpu().item()))
1060
+
1061
+ if (step + 1) % self._ttt_cfg.ags == 0:
1062
+ optimizer.step()
1063
+ optimizer.zero_grad(set_to_none=True)
1064
+
1065
+ self.eval()
1066
+ return {"losses": losses}
1067
+
1068
+ @preserve_model_state
1069
+ def ttt(self, seq: str) -> Dict[str, List[float]]:
1070
+ """Run test-time training on a single sequence using masked language modeling.
1071
+
1072
+ Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence
1073
+ before structure prediction. Call fold_protein(seq, ttt=True) for the full pipeline.
1074
+
1075
+ Args:
1076
+ seq: Protein sequence (single-letter amino acid codes)
1077
+
1078
+ Returns:
1079
+ Dict with "losses" key containing per-step MLM loss values
1080
+ """
1081
+ self._ensure_ttt_ready()
1082
+ if self._uses_lora:
1083
+ return self._lora_ttt(seq)
1084
+ return self._legacy_ttt(seq)
1085
+
1086
+ # ---- High-Level API ----
1087
+
1088
+ def fold_protein(
1089
+ self,
1090
+ sequence: str,
1091
+ ttt: bool = False,
1092
+ num_recycles: Optional[int] = None,
1093
+ return_pdb_string: bool = True,
1094
+ ) -> Dict[str, Any]:
1095
+ """Fold a protein sequence, optionally with test-time training.
1096
+
1097
+ Args:
1098
+ sequence: Protein sequence (single-letter amino acid codes)
1099
+ ttt: If True, run test-time training before folding (improves accuracy)
1100
+ num_recycles: Override default number of recycling iterations (None = use config default)
1101
+ return_pdb_string: If True, include PDB string in output
1102
+
1103
+ Returns:
1104
+ Dict with keys:
1105
+ - plddt: float, mean per-residue pLDDT confidence score
1106
+ - ptm: float, predicted TM-score
1107
+ - pdb_string: str (if return_pdb_string=True), PDB format structure
1108
+ - ttt_losses: list[float] (if ttt=True), per-step MLM losses
1109
+ """
1110
+ result: Dict[str, Any] = {}
1111
+
1112
+ if ttt:
1113
+ ttt_result = self.ttt(sequence)
1114
+ result["ttt_losses"] = ttt_result["losses"]
1115
+
1116
+ with torch.no_grad():
1117
+ output = self.infer(sequence, num_recycles=num_recycles)
1118
+
1119
+ plddt = output["plddt"]
1120
+ if plddt.dim() >= 2:
1121
+ mean_plddt = float(plddt.mean(dim=-1).mean().item())
1122
+ else:
1123
+ mean_plddt = float(plddt.mean().item())
1124
+
1125
+ result["plddt"] = mean_plddt
1126
+ result["ptm"] = float(output["ptm"].item()) if "ptm" in output else None
1127
+
1128
+ if return_pdb_string:
1129
+ pdb_strings = self.output_to_pdb(output)
1130
+ result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings
1131
+
1132
+ if ttt:
1133
+ self.ttt_reset()
1134
+
1135
+ return result
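
For reference, a minimal end-to-end usage sketch assembled from the module docstring and the fold_protein API above (the sequence and output filename are placeholders):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True)
if torch.cuda.is_available():
    model = model.cuda()

sequence = "MKTLLILAVVAAALA"  # placeholder sequence
result = model.fold_protein(sequence, ttt=True)  # TTT adapts the backbone, folds, then restores pre-TTT weights
print("mean pLDDT:", result["plddt"], "pTM:", result["ptm"])
print("TTT losses (first 5 steps):", result["ttt_losses"][:5])

with open("prediction.pdb", "w") as handle:  # placeholder output path
    handle.write(result["pdb_string"])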