lhallee committed on
Commit
355193c
·
verified ·
1 Parent(s): 2f071ee

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +1243 -1197
modeling_esm_plusplus.py CHANGED
@@ -1,1197 +1,1243 @@
1
- """
2
- ESM++ model implementation.
3
-
4
- ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
- The ESM Python package is not required
6
-
7
- Modified from https://github.com/evolutionaryscale/esm
8
- License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
- """
10
-
11
- import entrypoint_setup
12
- import math
13
- import os
14
- import warnings
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from dataclasses import dataclass
19
- from functools import cache, partial
20
- from pathlib import Path
21
- from typing import Optional, Tuple, Union, List
22
- from einops import rearrange, repeat
23
- from huggingface_hub import snapshot_download
24
- from tokenizers import Tokenizer
25
- from tokenizers.models import BPE
26
- from tokenizers.processors import TemplateProcessing
27
- from torch.nn.attention.flex_attention import create_block_mask
28
- from torch.nn.attention.flex_attention import flex_attention
29
- from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
- from transformers.modeling_outputs import ModelOutput
31
-
32
- try:
33
- # when used from AutoModel, these are in the same directory
34
- from .embedding_mixin import EmbeddingMixin, Pooler
35
- except:
36
- # when running from our repo, these are in the base directory
37
- from embedding_mixin import EmbeddingMixin, Pooler
38
-
39
-
40
- def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
41
- token_valid = attention_mask_2d.bool()
42
- batch_size, seq_len = token_valid.shape
43
-
44
- def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
45
- return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
46
-
47
- return create_block_mask(
48
- mask_mod,
49
- batch_size,
50
- 1,
51
- seq_len,
52
- seq_len,
53
- device=attention_mask_2d.device,
54
- )
55
-
56
-
57
class ESMplusplusConfig(PretrainedConfig):
    """Configuration class for ESM++ model.

    Args:
        vocab_size: Size of the vocabulary
        hidden_size: Dimension of hidden layers
        num_attention_heads: Number of attention heads
        num_hidden_layers: Number of transformer layers
        num_labels: Number of output labels for classification
        problem_type: Type of problem - regression, single/multi label classification
        dropout: Dropout probability applied on residual branches of each block
        initializer_range: Std of the normal init for Linear/Embedding weights
        attn_backend: Attention implementation, "sdpa" (default) or "flex"
    """
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        # Optional[str] instead of `str | None` for consistency with the rest
        # of this file, which imports and uses typing.Optional throughout
        problem_type: Optional[str] = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.dropout = dropout
        self.initializer_range = initializer_range
        # LM head and token embeddings are kept as separate parameter sets
        self.tie_word_embeddings = False
        self.attn_backend = attn_backend
93
-
94
-
95
- ### Rotary Embeddings
96
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate half the feature dimension: (x1, x2) -> (-x2, x1).

    With ``interleaved=False`` the halves are the first and second half of the
    last dim; with ``interleaved=True`` they are the even/odd positions, and
    the rotated pairs are woven back together in place.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # stack pairs (-odd, even) on a trailing axis, then flatten the last
        # two dims so each pair stays adjacent — same as the einops pattern
        # "... d two -> ... (d two)"
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
106
-
107
-
108
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to ``x`` using precomputed cos/sin.

    Only the first ``2 * cos.shape[-1]`` features are rotated; the remaining
    features pass through untouched. ``_inplace`` is accepted for signature
    compatibility but ignored.
    """
    rot_dim = cos.shape[-1] * 2
    assert rot_dim <= x.shape[-1]
    seq_len = x.size(1)
    # Duplicate the half-width tables to full rotary width and insert a head
    # axis: (s, d) -> (s, 1, 2d). Equivalent to einops repeat "s d -> s 1 (2 d)".
    cos_full = torch.cat((cos[:seq_len], cos[:seq_len]), dim=-1).unsqueeze(1)
    sin_full = torch.cat((sin[:seq_len], sin[:seq_len]), dim=-1).unsqueeze(1)
    rotary_part, passthrough = x[..., :rot_dim], x[..., rot_dim:]
    rotated = rotary_part * cos_full + rotate_half(rotary_part, interleaved) * sin_full
    return torch.cat([rotated, passthrough], dim=-1)
130
-
131
-
132
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """
    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # lazily-built cos/sin tables; rebuilt whenever sequence length grows
        # or device/dtype change (see _update_cos_sin_cache)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the embedding."""
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        # per-dimension scale table; only built when scale_base is set
        # (the scaled path is unsupported — forward asserts scale is None)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Update the cached cosine and sine values."""
        # rebuild when the request outgrows the cache, the device/dtype
        # changed, or a cached inference-mode tensor would leak into training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                # fp32 positions avoid precision loss for long sequences
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # freqs[i, j] = position_i * inv_freq_j
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # separate scaled tables for queries (*scale) and keys (/scale);
                # NOTE(review): this branch is never reachable through forward,
                # which asserts scale is None
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            # xPos-style scaled rotary is not supported in this implementation
            assert False
265
-
266
-
267
- ### Feedforward Network Components
268
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Round ``expansion_ratio * d_model`` up to the nearest multiple of 256."""
    raw_width = expansion_ratio * d_model
    return int((raw_width + 255) // 256 * 256)
271
-
272
-
273
class SwiGLU(nn.Module):
    """SwiGLU activation: splits the input in two along the feature dim and
    returns silu(gate) * value."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value
281
-
282
-
283
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Build the feedforward branch: LayerNorm -> up-projection -> SwiGLU ->
    down-projection, all linear layers bias-free."""
    inner_dim = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        # doubled width: one half gates the other inside SwiGLU
        nn.Linear(d_model, inner_dim * 2, bias=False),
        SwiGLU(),
        nn.Linear(inner_dim, d_model, bias=False),
    )
293
-
294
-
295
- ### Attention
296
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Uses a fused QKV projection with pre-LayerNorm, QK-norm, rotary position
    embeddings, and either SDPA or flex attention as the backend.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        attn_backend: "sdpa" (default) or "flex"; flex falls back to SDPA on failure
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.attn_backend = attn_backend
        # one-shot guard so the flex -> SDPA fallback warning fires only once
        self._warned_flex_fallback = False
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # QK-norm: LayerNorm applied to queries and keys before rotary
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key."""
        # rotary expects (b, s, h, d_head); unflatten, rotate, flatten back
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask; masked_fill on logical_not
                implies a boolean mask where True means "may attend"
            flex_block_mask: Optional precomputed BlockMask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after self attention, and optionally attention weights
        """
        attn_weights = None
        # suffix convention: B=batch, L=seq len, D=model dim, H=heads
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        scale = 1 / math.sqrt(self.d_head)

        if output_attentions:  # Manual attention computation
            # explicit softmax(QK^T) path so the weights can be returned
            b, h, l, _ = query_BHLD.shape
            attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
            attn_weights += attn_bias
            attn_weights = F.softmax(attn_weights, dim=-1)
            context_BHLD = torch.matmul(attn_weights, value_BHLD)
        else:
            # additive -inf mask for SDPA, built once and reused by the
            # flex-attention fallback path below
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # flex is only usable when unmasked, or when the caller supplied a
            # matching BlockMask
            use_flex = (
                self.attn_backend == "flex"
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    context_BHLD = flex_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        block_mask=flex_block_mask,
                        scale=scale,
                    )
                except Exception as exc:
                    if not self._warned_flex_fallback:
                        warnings.warn(
                            f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
                            RuntimeWarning,
                        )
                        self._warned_flex_fallback = True
                    context_BHLD = F.scaled_dot_product_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        attn_mask=sdpa_mask,
                        scale=scale,
                    )
            else:
                context_BHLD = F.scaled_dot_product_attention(
                    query_BHLD,
                    key_BHLD,
                    value_BHLD,
                    attn_mask=sdpa_mask,
                    scale=scale,
                )

        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        output = self.out_proj(context_BLD)
        return output, attn_weights
413
-
414
-
415
- ### Regression Head
416
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Build a two-layer MLP head: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        d_model: Input dimension
        output_dim: Output dimension
        hidden_dim: Optional hidden dimension (defaults to d_model)
    """
    width = d_model if hidden_dim is None else hidden_dim
    return nn.Sequential(
        nn.Linear(d_model, width),
        nn.GELU(),
        nn.LayerNorm(width),
        nn.Linear(width, output_dim),
    )
431
-
432
-
433
- ### Transformer Block
434
class UnifiedTransformerBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention + SwiGLU feedforward.

    Each residual branch is divided by ``residue_scaling_factor`` before being
    added back to the stream.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        residue_scaling_factor: Divisor applied to each residual branch
        expansion_ratio: Expansion ratio for the feedforward network
        dropout: Dropout probability applied to both residual branches
        attn_backend: "sdpa" or "flex" attention implementation
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            attn_backend=attn_backend,
        )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention then feedforward, each as a scaled residual update.

        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional BlockMask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Updated hidden states and (optionally) attention weights
        """
        attn_out, attn_weights = self.attn(
            x,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        x = x + self.dropout(attn_out) / self.scaling_factor
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out) / self.scaling_factor
        return x, attn_weights
487
-
488
-
489
- ### Model Outputs
490
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder."""
    # final hidden states after the stack's closing LayerNorm
    last_hidden_state: Optional[torch.Tensor] = None
    # per-layer hidden states, populated when output_hidden_states=True
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # per-layer attention weights, populated when output_attentions=True
    attentions: Optional[Tuple[torch.Tensor]] = None
496
-
497
-
498
@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models."""
    # task loss, present only when labels are provided
    loss: Optional[torch.Tensor] = None
    # head logits (LM vocabulary or classification labels, depending on model)
    logits: Optional[torch.Tensor] = None
    # final hidden states from the transformer stack
    last_hidden_state: Optional[torch.Tensor] = None
    # per-layer hidden states, populated when output_hidden_states=True
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # per-layer attention weights, populated when output_attentions=True
    attentions: Optional[Tuple[torch.Tensor]] = None
506
-
507
-
508
- ### Transformer Stack
509
class TransformerStack(nn.Module):
    """Stack of transformer blocks.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dropout: Dropout rate
        attn_backend: "sdpa" or "flex" attention implementation
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn_backend = attn_backend
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    # ESMC residual scaling, normalized to a 36-layer reference
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                    attn_backend=attn_backend,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        # toggled externally by HF's gradient_checkpointing_enable machinery
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TransformerOutput:
        """
        Args:
            x: Input tensor
            attention_mask: Optional 2D (batch, seq) token attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        if attention_mask is not None:
            assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
            token_attention_mask = attention_mask.bool()
            # expand per-token mask to a pairwise (b, 1, s, s) boolean mask
            pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
            attention_mask = pairwise_attention_mask.unsqueeze(1)
            # the flex path cannot return attention weights, so the BlockMask
            # is only built when weights are not requested
            if self.attn_backend == "flex" and not output_attentions:
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
            else:
                flex_block_mask = None
        else:
            flex_block_mask = None

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # recompute activations during backward to save memory
                x, attn_weights = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)

            if attentions is not None:
                attentions += (attn_weights,)

            if output_hidden_states:
                assert hidden_states is not None
                hidden_states += (x,)

        return TransformerOutput(
            last_hidden_state=self.norm(x),
            hidden_states=hidden_states,
            attentions=attentions
        )
599
-
600
-
601
class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    init weights for ESM++ models
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        # normal(0, initializer_range) for Linear/Embedding, ones/zeros for LayerNorm
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # keep the padding token's embedding at exactly zero
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model sized by the ESMC tag in its name."""
        for tag, factory in (('300', ESMplusplus_300M), ('600', ESMplusplus_600M)):
            if tag in model_name:
                return factory()
        raise ValueError(f"Invalid model name: {model_name}")
634
-
635
-
636
- ### ESM++ Models
637
class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model. transformer model with no heads
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        # token embedding table feeding the transformer stack
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed tokens and run the stack, returning only the final hidden states."""
        encoded = self.transformer(
            self.embed(input_ids),
            attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> TransformerOutput:
        """Encode a batch of sequences with the headless transformer.

        Args:
            input_ids: Input token IDs (ignored when inputs_embeds is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        hidden = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        return self.transformer(hidden, attention_mask, output_hidden_states, output_attentions)
694
-
695
-
696
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model for masked language modeling.
    Implements the base ESM++ architecture with a masked language modeling head.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        # LM head projecting hidden states back to vocabulary logits
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # the final Linear of the LM head plays the role of output embeddings
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed tokens and run the stack, returning only the final hidden states."""
        encoded = self.transformer(
            self.embed(input_ids),
            attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for masked language modeling.

        Args:
            input_ids: Input token IDs (ignored when inputs_embeds is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for masked tokens
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and attention weights
        """
        hidden = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        encoded = self.transformer(hidden, attention_mask, output_hidden_states, output_attentions)
        final_hidden = encoded.last_hidden_state
        logits = self.sequence_head(final_hidden)
        loss = (
            None
            if labels is None
            else self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        )
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_hidden,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
776
-
777
-
778
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # input is hidden_size * 2 because the pooler concatenates cls + mean
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        # Use caller-provided pooling_types when it is a non-empty list,
        # otherwise default to ['cls', 'mean'].
        # BUG FIX: the previous check was isinstance(x, List[str]), which
        # raises TypeError at runtime — subscripted generics cannot be used
        # with isinstance — so passing pooling_types always crashed __init__.
        pooling_types = kwargs.get('pooling_types')
        if not (isinstance(pooling_types, list) and len(pooling_types) > 0):
            pooling_types = ['cls', 'mean']
        self.pooler = Pooler(pooling_types)
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed tokens and run the stack, returning only the final hidden states."""
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # run the backbone through the MLM parent without computing its LM loss
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # infer problem_type once (HF convention) and cache it on the config
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
867
-
868
-
869
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for token classification.
    Extends the base ESM++ model with a token classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projections help with sequence classification tasks (*4)
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed tokens and run the stack, returning only the final hidden states."""
        encoded = self.transformer(
            self.embed(input_ids),
            attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional per-token labels
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # run the backbone through the MLM parent without computing its LM loss
        base = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        hidden = base.last_hidden_state
        logits = self.classifier(hidden)
        loss = (
            None
            if labels is None
            else self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        )
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden,
            hidden_states=base.hidden_states,
            attentions=base.attentions,
        )
931
-
932
-
933
- ### Loading from EvolutionaryScale
934
# Canonical ESMC checkpoint metadata, keyed by the short names produced by
# _resolve_esmc_checkpoint_key. The dimension entries are asserted against the
# model config at load time to fail fast on mismatched weights.
_ESMC_CHECKPOINT_SPECS = {
    "esmc-300": {
        "repo_id": "EvolutionaryScale/esmc-300m-2024-12",
        "weights_relpath": "data/weights/esmc_300m_2024_12_v0.pth",
        "hidden_size": 960,
        "num_attention_heads": 15,
        "num_hidden_layers": 30,
    },
    "esmc-600": {
        "repo_id": "EvolutionaryScale/esmc-600m-2024-12",
        "weights_relpath": "data/weights/esmc_600m_2024_12_v0.pth",
        "hidden_size": 1152,
        "num_attention_heads": 18,
        "num_hidden_layers": 36,
    },
}
950
-
951
-
952
- def _resolve_esmc_checkpoint_key(model: str) -> str:
953
- if "esmc-300" in model:
954
- return "esmc-300"
955
- if "esmc-600" in model:
956
- return "esmc-600"
957
- raise ValueError(f"{model=} is an invalid ESMC model name.")
958
-
959
-
960
@cache
def data_root(model: str) -> Path:
    """Return the local directory that holds the ESMC weights for ``model``.

    Downloads (or reuses) the HuggingFace snapshot for the matching checkpoint.
    When the INFRA_PROVIDER environment variable is set, weights are assumed to
    be provided by the infrastructure and an empty path is returned instead.

    Args:
        model: Any model name containing "esmc-300" or "esmc-600".

    Returns:
        Path to the snapshot root (empty Path under INFRA_PROVIDER).
    """
    # NOTE: the original code stacked @staticmethod on this module-level
    # function (a leftover from a class context); that made the object
    # non-callable on Python < 3.10 and is meaningless at module scope.
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    key = _resolve_esmc_checkpoint_key(model)
    return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"]))
967
-
968
-
969
def get_esmc_checkpoint_path(model: str) -> Path:
    """Resolve the on-disk path of the ``.pth`` weights file for ``model``."""
    key = _resolve_esmc_checkpoint_key(model)
    relative_weights = _ESMC_CHECKPOINT_SPECS[key]["weights_relpath"]
    return data_root(key) / relative_weights
972
-
973
-
974
def _load_esmc_checkpoint_model(
    config: ESMplusplusConfig,
    model: str,
    device: torch.device | str = "cpu",
) -> ESMplusplusForMaskedLM:
    """Build an ESM++ MLM model and load the matching official ESMC weights.

    Args:
        config: Model configuration; its dimensions must match the checkpoint spec.
        model: Any name containing "esmc-300" or "esmc-600".
        device: Device used to construct the module and map the weights.

    Returns:
        ESMplusplusForMaskedLM with pretrained weights loaded.
    """
    key = _resolve_esmc_checkpoint_key(model)
    spec = _ESMC_CHECKPOINT_SPECS[key]
    # Fail fast with a clear message instead of a cryptic load_state_dict error.
    assert config.hidden_size == spec["hidden_size"], (
        f"ESMC loader expected hidden_size={spec['hidden_size']} for {key}, "
        f"but got {config.hidden_size}."
    )
    assert config.num_attention_heads == spec["num_attention_heads"], (
        f"ESMC loader expected num_attention_heads={spec['num_attention_heads']} for {key}, "
        f"but got {config.num_attention_heads}."
    )
    assert config.num_hidden_layers == spec["num_hidden_layers"], (
        f"ESMC loader expected num_hidden_layers={spec['num_hidden_layers']} for {key}, "
        f"but got {config.num_hidden_layers}."
    )
    # torch.device context so parameters are created directly on the target device.
    with torch.device(device):
        model_obj = ESMplusplusForMaskedLM(config)
    # NOTE(review): torch.load without weights_only=True can execute pickled
    # code; acceptable only because the checkpoint comes from a trusted repo.
    state_dict = torch.load(get_esmc_checkpoint_path(key), map_location=device)
    model_obj.load_state_dict(state_dict)
    return model_obj
998
-
999
-
1000
def ESMplusplus_300M(device: torch.device | str = "cpu"):
    """Instantiate ESM++ 300M and load the official ESMC-300M checkpoint."""
    spec = _ESMC_CHECKPOINT_SPECS["esmc-300"]
    config = ESMplusplusConfig(
        hidden_size=spec["hidden_size"],
        num_attention_heads=spec["num_attention_heads"],
        num_hidden_layers=spec["num_hidden_layers"],
    )
    return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device)
1007
-
1008
-
1009
def ESMplusplus_600M(device: torch.device | str = "cpu"):
    """Instantiate ESM++ 600M and load the official ESMC-600M checkpoint."""
    spec = _ESMC_CHECKPOINT_SPECS["esmc-600"]
    config = ESMplusplusConfig(
        hidden_size=spec["hidden_size"],
        num_attention_heads=spec["num_attention_heads"],
        num_hidden_layers=spec["num_hidden_layers"],
    )
    return _load_esmc_checkpoint_model(config=config, model="esmc-600", device=device)
1016
-
1017
-
1018
- ### Tokenization
1019
### Tokenization
# Order matters: EsmSequenceTokenizer assigns token ids by list position
# (via enumerate), so this list must stay in sync with the ESMC vocabulary.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]
1026
-
1027
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level protein tokenizer matching the ESMC vocabulary."""
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        # "|" separates chains in multi-chain inputs.
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        # Every valid id, including special tokens.
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
1105
-
1106
if __name__ == "__main__":
    # Smoke tests: run each model head once on a sample protein sequence.
    # Pick GPU when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Test tokenizer
    tokenizer = EsmSequenceTokenizer()
    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
    encoding = tokenizer(sample_sequence, return_tensors="pt")
    print(f"Input sequence length: {len(sample_sequence)}")
    print(f"Tokenized sequence: {encoding['input_ids'].shape}")

    # Prepare inputs
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Test base model with smaller config for quick testing
    print("\n=== Testing ESMplusplus Base Model ===")
    base_config = ESMplusplusConfig(
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=4
    )
    base_model = ESMplusplusModel(base_config).to(device)

    with torch.no_grad():
        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")

    # Test embedding functionality
    print("\nTesting embedding functionality:")
    with torch.no_grad():
        embeddings = base_model._embed(input_ids, attention_mask)
        print(f"Embedding shape: {embeddings.shape}")

    # Test masked language modeling
    print("\n=== Testing ESMplusplus For Masked LM ===")
    mlm_model = ESMplusplusForMaskedLM(base_config).to(device)

    with torch.no_grad():
        outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test sequence classification model
    print("\n=== Testing Sequence Classification Model ===")
    classification_model = ESMplusplusForSequenceClassification(base_config).to(device)

    with torch.no_grad():
        outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test token classification model
    print("\n=== Testing Token Classification Model ===")
    token_model = ESMplusplusForTokenClassification(base_config).to(device)

    with torch.no_grad():
        outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test embedding dataset functionality with a mini dataset
    print("\n=== Testing Embed Dataset Functionality ===")
    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
    print(f"Creating embeddings for {len(mini_dataset)} sequences")

    # Only run this if save path doesn't exist to avoid overwriting
    if not os.path.exists("test_embeddings.pth"):
        embeddings = mlm_model.embed_dataset(
            sequences=mini_dataset,
            tokenizer=tokenizer,
            batch_size=2,
            max_len=100,
            full_embeddings=False,
            pooling_types=['mean'],
            save_path="test_embeddings.pth"
        )
        if embeddings:
            print(f"Embedding dictionary size: {len(embeddings)}")
            # Show one (sequence, embedding) pair as a sanity check.
            for seq, emb in embeddings.items():
                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
                break
    else:
        print("Skipping embedding test as test_embeddings.pth already exists")

    print("\nAll tests completed successfully!")
1197
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ESM++ model implementation.
3
+
4
+ ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
+ The ESM Python package is not required
6
+
7
+ Modified from https://github.com/evolutionaryscale/esm
8
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
+ """
10
+
11
+ import entrypoint_setup
12
+ import math
13
+ import os
14
+ import warnings
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from dataclasses import dataclass
19
+ from functools import cache, partial
20
+ from pathlib import Path
21
+ from typing import Optional, Tuple, Union, List
22
+ from einops import rearrange, repeat
23
+ from huggingface_hub import snapshot_download
24
+ from tokenizers import Tokenizer
25
+ from tokenizers.models import BPE
26
+ from tokenizers.processors import TemplateProcessing
27
+ from torch.nn.attention.flex_attention import create_block_mask
28
+ from torch.nn.attention.flex_attention import flex_attention
29
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
+ from transformers.modeling_outputs import ModelOutput
31
+
32
+ try:
33
+ # when used from AutoModel, these are in the same directory
34
+ from .embedding_mixin import EmbeddingMixin, Pooler
35
+ except:
36
+ # when running from our repo, these are in the base directory
37
+ from embedding_mixin import EmbeddingMixin, Pooler
38
+
39
+
40
+ def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
41
+ token_valid = attention_mask_2d.bool()
42
+ batch_size, seq_len = token_valid.shape
43
+
44
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
45
+ return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
46
+
47
+ return create_block_mask(
48
+ mask_mod,
49
+ batch_size,
50
+ 1,
51
+ seq_len,
52
+ seq_len,
53
+ device=attention_mask_2d.device,
54
+ )
55
+
56
+
57
class ESMplusplusConfig(PretrainedConfig):
    """Configuration class for ESM++ model.

    Args:
        vocab_size: Size of the vocabulary
        hidden_size: Dimension of hidden layers
        num_attention_heads: Number of attention heads
        num_hidden_layers: Number of transformer layers
        num_labels: Number of output labels for classification
        problem_type: Type of problem - regression, single/multi label classification
        dropout: Dropout rate applied in the transformer blocks
        initializer_range: Std-dev of the normal init used for weights
        attn_backend: Attention implementation, "sdpa" (default) or "flex"
    """
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.dropout = dropout
        self.initializer_range = initializer_range
        # Output head weights are never tied to the input embeddings.
        self.tie_word_embeddings = False
        self.attn_backend = attn_backend
93
+
94
+
95
+ ### Rotary Embeddings
96
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate half the hidden dims of the input (RoPE helper).

    Args:
        x: Input tensor; the rotation acts on the last dimension.
        interleaved: If True, rotary pairs are the interleaved positions
            (x[..., 2i], x[..., 2i+1]); otherwise the last dimension is split
            into two contiguous halves.

    Returns:
        Tensor of the same shape with each rotary pair (a, b) mapped to (-b, a).
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Pure-torch equivalent of einops' "... d two -> ... (d two)" re-interleave:
    # stack the rotated pairs on a trailing axis, then flatten it back in.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
106
+
107
+
108
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,  # accepted for API compatibility; this implementation is always out-of-place
) -> torch.Tensor:
    """Apply rotary embeddings to input based on cos and sin.

    Args:
        x: Input tensor; assumes layout (batch, seqlen, nheads, headdim) -- TODO confirm with callers.
        cos: Cosine table; last dim is rotary_dim / 2 (rotary_dim inferred below).
        sin: Sine table, same shape as ``cos``.
        interleaved: Whether rotary pairs are interleaved rather than split halves.

    Returns:
        ``x`` with rotary embedding applied to the first ``ro_dim`` channels;
        the remaining channels pass through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    # Truncate the tables to the actual sequence length.
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    # Broadcast over heads and duplicate each frequency to cover both rotated halves.
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
130
+
131
+
132
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """
    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # Lazy cos/sin caches; rebuilt when seqlen, device, or dtype changes.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Device inv_freq was last computed on; used to detect device moves.
        self._inv_freq_compute_device: Optional[torch.device] = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the embedding."""
        # inv_freq is non-persistent, so recompute it on the device the buffer
        # currently lives on (or the configured device before first registration).
        if "inv_freq" in self._buffers and isinstance(self._buffers["inv_freq"], torch.Tensor):
            buffer_device = self._buffers["inv_freq"].device
        else:
            buffer_device = self.device
        inv_freq = self._compute_inv_freq(buffer_device)
        self._inv_freq_compute_device = inv_freq.device
        # Invalidate all caches; they are rebuilt on the next forward.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=buffer_device, dtype=torch.float32)
        # XPos-style scale; only used when scale_base is set (unsupported in forward).
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Update the cached cosine and sine values."""
        # Rebuild when the request outgrows the cache, the device/dtype changed,
        # or a cache built under inference_mode is reused during training.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                # fp32 positions avoid precision loss for long sequences.
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos variant: separate scaled tables for queries and keys.
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        assert self._inv_freq_compute_device is not None, "Rotary inv_freq compute device should be set after initialization."
        # If the module moved devices since inv_freq was computed, rebuild it.
        if self._inv_freq_compute_device != q.device:
            self.reset_parameters()
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            # The XPos (scale_base) path is intentionally unsupported here.
            assert False
279
+
280
+
281
+ ### Feedforward Network Components
282
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Round ``expansion_ratio * d_model`` up to the next multiple of 256."""
    raw_width = expansion_ratio * d_model
    return int(((raw_width + 255) // 256) * 256)
285
+
286
+
287
class SwiGLU(nn.Module):
    """SwiGLU gated activation: silu(gate) * value over a channel-split input."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half: first half gates the second.
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value
295
+
296
+
297
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Build the ESM++ feedforward: LayerNorm -> gated up-projection -> SwiGLU -> down-projection.

    Both linear layers are bias-free; the hidden width is rounded to a
    multiple of 256 by ``swiglu_correction_fn``.
    """
    hidden = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        # Double width: one half gates the other inside SwiGLU.
        nn.Linear(d_model, hidden * 2, bias=False),
        SwiGLU(),
        nn.Linear(hidden, d_model, bias=False),
    )
307
+
308
+
309
+ ### Attention
310
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        attn_backend: "sdpa" or "flex"; flex is only attempted for fp16/bf16 inputs
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.attn_backend = attn_backend
        self._warned_flex_fallback = False  # warn at most once per module
        # Fused pre-norm + QKV projection (no biases, matching ESMC).
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # QK-norm: separate LayerNorms on queries and keys.
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key."""
        # Rotary operates per head, so temporarily expose the head dimension.
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask (pairwise, broadcast over heads)
            flex_block_mask: Optional BlockMask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after self attention, and optionally attention weights
        """
        attn_weights = None
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        scale = 1 / math.sqrt(self.d_head)

        if output_attentions:  # Manual attention computation
            # Fused kernels cannot return weights, so compute softmax(QK^T)V by hand.
            b, h, l, _ = query_BHLD.shape
            attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
            attn_weights += attn_bias
            attn_weights = F.softmax(attn_weights, dim=-1)
            context_BHLD = torch.matmul(attn_weights, value_BHLD)
        else:
            # SDPA expects an additive float mask (-inf at masked positions).
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            flex_dtype_supported = query_BHLD.dtype in (torch.float16, torch.bfloat16)
            # Use flex only when the dtype is supported and either no mask is
            # needed or a BlockMask was supplied by the caller.
            use_flex = (
                self.attn_backend == "flex"
                and flex_dtype_supported
                and (attention_mask is None or flex_block_mask is not None)
            )
            if self.attn_backend == "flex" and not flex_dtype_supported and not self._warned_flex_fallback:
                warnings.warn(
                    "Flex attention backend requested in float32; falling back to SDPA for strict numerical parity.",
                    RuntimeWarning,
                )
                self._warned_flex_fallback = True
            if use_flex:
                try:
                    context_BHLD = flex_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        block_mask=flex_block_mask,
                        scale=scale,
                    )
                except Exception as exc:
                    # Best-effort: any flex failure degrades to SDPA with one warning.
                    if not self._warned_flex_fallback:
                        warnings.warn(
                            f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
                            RuntimeWarning,
                        )
                        self._warned_flex_fallback = True
                    context_BHLD = F.scaled_dot_product_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        attn_mask=sdpa_mask,
                        scale=scale,
                    )
            else:
                context_BHLD = F.scaled_dot_product_attention(
                    query_BHLD,
                    key_BHLD,
                    value_BHLD,
                    attn_mask=sdpa_mask,
                    scale=scale,
                )

        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        output = self.out_proj(context_BLD)
        return output, attn_weights
435
+
436
+
437
+ ### Regression Head
438
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Build a two-layer MLP head: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        d_model: Input feature dimension.
        output_dim: Number of outputs.
        hidden_dim: Width of the intermediate layer; defaults to ``d_model``.

    Returns:
        nn.Sequential implementing the head.
    """
    if hidden_dim is None:
        hidden_dim = d_model
    layers = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    ]
    return nn.Sequential(*layers)
453
+
454
+
455
+ ### Transformer Block
456
class UnifiedTransformerBlock(nn.Module):
    """Transformer block with attention and feedforward layers.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        residue_scaling_factor: Factor for scaling residual connections
        expansion_ratio: Expansion ratio for feedforward network
        dropout: Dropout applied to attention and FFN outputs
        attn_backend: Attention implementation forwarded to MultiHeadAttention
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            attn_backend=attn_backend,
        )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional BlockMask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after transformer block, and optionally attention weights
        """
        attn_output, attn_weights = self.attn(
            x,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        # Residual branches are divided by the depth-dependent scaling factor.
        x = x + self.dropout(attn_output) / self.scaling_factor
        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        return x, attn_weights
509
+
510
+
511
+ ### Model Outputs
512
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder."""
    # Final-layer activations (after the stack's output LayerNorm).
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer activations, populated when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # Per-layer attention maps, populated when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor]] = None
518
+
519
+
520
@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models."""
    # Scalar loss tensor, present only when labels were supplied.
    loss: Optional[torch.Tensor] = None
    # Head outputs (LM or classification logits, depending on the model).
    logits: Optional[torch.Tensor] = None
    # Final-layer activations of the transformer stack.
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer activations, populated when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # Per-layer attention maps, populated when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor]] = None
528
+
529
+
530
+ ### Transformer Stack
531
class TransformerStack(nn.Module):
    """Stack of transformer blocks.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dropout: Dropout rate
        attn_backend: "sdpa" or "flex"; flex builds a BlockMask from padding
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn_backend = attn_backend
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    # Residual scaling grows with depth: sqrt(n_layers / 36).
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                    attn_backend=attn_backend,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TransformerOutput:
        """
        Args:
            x: Input tensor
            attention_mask: Optional (batch, seq) token attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        if attention_mask is not None:
            assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
            token_attention_mask = attention_mask.bool()
            # Expand the per-token mask to a (b, 1, q, kv) pairwise mask for SDPA.
            pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
            attention_mask = pairwise_attention_mask.unsqueeze(1)
            # The flex backend uses a BlockMask instead; it cannot return
            # attention weights, so skip it when attentions are requested.
            if self.attn_backend == "flex" and not output_attentions:
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
            else:
                flex_block_mask = None
        else:
            flex_block_mask = None

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward to save memory.
                x, attn_weights = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)

            if attentions is not None:
                attentions += (attn_weights,)

            if output_hidden_states:
                assert hidden_states is not None
                hidden_states += (x,)

        return TransformerOutput(
            last_hidden_state=self.norm(x),
            hidden_states=hidden_states,
            attentions=attentions
        )
621
+
622
+
623
class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    init weights for ESM++ models
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True
    all_tied_weights_keys = {}  # no weight tying in ESM++

    def _init_weights(self, module):
        """Initialize the weights"""
        # HF from_pretrained marks loaded parameters with `_is_hf_initialized`.
        # Skip this module if any local parameter is already marked as loaded.
        for parameter in module.parameters(recurse=False):
            if "_is_hf_initialized" in parameter.__dict__ and parameter.__dict__["_is_hf_initialized"]:
                return

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def _reset_rotary_embeddings(self):
        """Refresh non-persistent rotary buffers after checkpoint loading."""
        for module in self.modules():
            if isinstance(module, RotaryEmbedding):
                module.reset_parameters()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Standard HF loading plus a rotary refresh: inv_freq is a
        # non-persistent buffer and is not restored from the checkpoint.
        output_loading_info = bool(kwargs["output_loading_info"]) if "output_loading_info" in kwargs else False
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        if output_loading_info:
            # In this mode HF returns a (model, loading_info) tuple.
            model, loading_info = loaded
            model._reset_rotary_embeddings()
            return model, loading_info
        loaded._reset_rotary_embeddings()
        return loaded

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model."""
        if '300' in model_name:
            return ESMplusplus_300M()
        elif '600' in model_name:
            return ESMplusplus_600M()
        else:
            raise ValueError(f"Invalid model name: {model_name}")
680
+
681
+
682
+ ### ESM++ Models
683
class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    Headless ESM++ model: a token embedding table followed by the
    transformer stack, with no task-specific head attached.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: a single forward pass returning only the final
        # hidden state (no per-layer states, no attention maps).
        token_states = self.embed(input_ids)
        encoded = self.transformer(token_states, attention_mask, output_hidden_states=False, output_attentions=False)
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> TransformerOutput:
        """Encode tokens (or precomputed embeddings) with the transformer stack.

        Args:
            input_ids: Input token IDs (ignored when `inputs_embeds` is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states

        Returns:
            TransformerOutput containing last hidden state and optionally all
            hidden states and attention weights
        """
        token_states = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        return self.transformer(token_states, attention_mask, output_hidden_states, output_attentions)
740
+
741
+
742
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ encoder with a language-modeling head over the sequence vocabulary,
    used for masked-token prediction.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # The LM projection is the last layer of the sequence head.
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: final hidden state only, no extra outputs.
        token_states = self.embed(input_ids)
        encoded = self.transformer(token_states, attention_mask, output_hidden_states=False, output_attentions=False)
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for masked language modeling.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for masked tokens
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and
            attention weights
        """
        token_states = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        encoder_out = self.transformer(token_states, attention_mask, output_hidden_states, output_attentions)
        final_states = encoder_out.last_hidden_state
        logits = self.sequence_head(final_states)
        loss = None
        if labels is not None:
            # Flatten batch and sequence dims for token-level cross entropy.
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_states,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
822
+
823
+
824
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head over pooled features.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projections help with sequence classification tasks (*4);
        # input is hidden_size * 2 because the pooler concatenates cls + mean by default.
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        # Use caller-supplied pooling_types when it is a non-empty list of strings,
        # otherwise default to ['cls', 'mean'].
        # BUGFIX: the previous check was `isinstance(x, List[str])`, which raises
        # TypeError at runtime — subscripted generics cannot be used with isinstance().
        pooling_types = kwargs.get('pooling_types')
        if not (
            isinstance(pooling_types, list)
            and len(pooling_types) > 0
            and all(isinstance(p, str) for p in pooling_types)
        ):
            pooling_types = ['cls', 'mean']
        self.pooler = Pooler(pooling_types)
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: final hidden state only, no extra outputs.
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for classification
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # Run the shared encoder (labels=None so the parent skips the MLM loss).
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once (HF convention) from label count/dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
913
+
914
+
915
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    Per-residue classifier built on the ESM++ encoder: each token position
    receives its own label logits.
    """

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projections help with sequence classification tasks (*4)
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: one pass, final hidden state only.
        token_states = self.embed(input_ids)
        encoded = self.transformer(token_states, attention_mask, output_hidden_states=False, output_attentions=False)
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for token classification
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # Run the shared encoder (labels=None so the parent skips the MLM loss).
        encoder_out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        sequence_output = encoder_out.last_hidden_state
        logits = self.classifier(sequence_output)
        loss = None if labels is None else self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
977
+
978
+
979
+ ### Loading from EvolutionaryScale
980
# Metadata for the two published ESMC checkpoints on the HuggingFace Hub.
# Each entry records the Hub repo, the weights file's path inside the snapshot,
# and the architecture hyperparameters that a matching ESMplusplusConfig must
# use for the state dict to load (validated in _load_esmc_checkpoint_model).
_ESMC_CHECKPOINT_SPECS = {
    "esmc-300": {
        "repo_id": "EvolutionaryScale/esmc-300m-2024-12",
        "weights_relpath": "data/weights/esmc_300m_2024_12_v0.pth",
        "hidden_size": 960,
        "num_attention_heads": 15,
        "num_hidden_layers": 30,
    },
    "esmc-600": {
        "repo_id": "EvolutionaryScale/esmc-600m-2024-12",
        "weights_relpath": "data/weights/esmc_600m_2024_12_v0.pth",
        "hidden_size": 1152,
        "num_attention_heads": 18,
        "num_hidden_layers": 36,
    },
}
996
+
997
+
998
+ def _resolve_esmc_checkpoint_key(model: str) -> str:
999
+ if "esmc-300" in model:
1000
+ return "esmc-300"
1001
+ if "esmc-600" in model:
1002
+ return "esmc-600"
1003
+ raise ValueError(f"{model=} is an invalid ESMC model name.")
1004
+
1005
+
1006
@cache
def data_root(model: str):
    """Return the local root directory that contains ESMC checkpoint data.

    Results are cached per model name so the Hub snapshot is downloaded at most
    once per process. When the INFRA_PROVIDER environment variable is set, the
    data is assumed to live relative to the current directory.

    BUGFIX: this module-level function was wrapped in @staticmethod — a
    class-body decorator. At module level it makes direct calls fail with
    TypeError on Python < 3.10 and is misleading on newer versions; removed.
    """
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    key = _resolve_esmc_checkpoint_key(model)
    return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"]))
1013
+
1014
+
1015
def get_esmc_checkpoint_path(model: str) -> Path:
    """Resolve the absolute path of the .pth weights file for an ESMC model."""
    checkpoint_key = _resolve_esmc_checkpoint_key(model)
    weights_relpath = _ESMC_CHECKPOINT_SPECS[checkpoint_key]["weights_relpath"]
    return data_root(checkpoint_key) / weights_relpath
1018
+
1019
+
1020
def _load_esmc_checkpoint_model(
    config: ESMplusplusConfig,
    model: str,
    device: torch.device | str = "cpu",
) -> ESMplusplusForMaskedLM:
    """Build an ESMplusplusForMaskedLM and load the matching ESMC state dict.

    The supplied config must reproduce the checkpoint's architecture
    hyperparameters exactly, otherwise the state dict cannot load.
    """
    key = _resolve_esmc_checkpoint_key(model)
    spec = _ESMC_CHECKPOINT_SPECS[key]
    # Validate config against the checkpoint spec, one attribute at a time.
    for attr_name in ("hidden_size", "num_attention_heads", "num_hidden_layers"):
        expected = spec[attr_name]
        actual = getattr(config, attr_name)
        assert actual == expected, (
            f"ESMC loader expected {attr_name}={expected} for {key}, "
            f"but got {actual}."
        )
    with torch.device(device):
        model_obj = ESMplusplusForMaskedLM(config)
        state_dict = torch.load(get_esmc_checkpoint_path(key), map_location=device)
        model_obj.load_state_dict(state_dict)
    return model_obj
1044
+
1045
+
1046
def ESMplusplus_300M(device: torch.device | str = "cpu"):
    """Build the 300M-parameter ESM++ model and load the ESMC weights."""
    return _load_esmc_checkpoint_model(
        config=ESMplusplusConfig(
            hidden_size=960,
            num_attention_heads=15,
            num_hidden_layers=30,
        ),
        model="esmc-300",
        device=device,
    )
1053
+
1054
+
1055
def ESMplusplus_600M(device: torch.device | str = "cpu"):
    """Build the 600M-parameter ESM++ model and load the ESMC weights."""
    return _load_esmc_checkpoint_model(
        config=ESMplusplusConfig(
            hidden_size=1152,
            num_attention_heads=18,
            num_hidden_layers=36,
        ),
        model="esmc-600",
        device=device,
    )
1062
+
1063
+
1064
+ ### Tokenization
1065
# Vocabulary for EsmSequenceTokenizer. Token ids are assigned by list position
# (see the enumerate() in EsmSequenceTokenizer.__init__), so the order must not
# change: control tokens first, then amino-acid/ambiguity codes, then
# punctuation and the chain-break marker "|", and finally "<mask>".
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]
1072
+
1073
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level protein sequence tokenizer over SEQUENCE_VOCAB.

    Built as a BPE model with no merges, so every character is its own token.
    Wraps a `tokenizers.Tokenizer` in the HF fast-tokenizer interface.
    """
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        # Ids follow vocabulary order, matching the embedding table layout.
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        # Alias: the class token stands in for "beginning of sequence".
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        # The "|" separator between chains in multi-chain inputs.
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        # Every id in the vocabulary, in order.
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
1150
+
1151
+
1152
if __name__ == "__main__":
    # Smoke tests: tokenize one sequence, then run every head variant on a
    # small configuration and report output shapes.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {run_device}")

    # Tokenizer round-trip on a sample ubiquitin-like sequence.
    tokenizer = EsmSequenceTokenizer()
    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
    encoding = tokenizer(sample_sequence, return_tensors="pt")
    print(f"Input sequence length: {len(sample_sequence)}")
    print(f"Tokenized sequence: {encoding['input_ids'].shape}")

    input_ids = encoding['input_ids'].to(run_device)
    attention_mask = encoding['attention_mask'].to(run_device)

    # Small config so the smoke test finishes quickly on CPU.
    base_config = ESMplusplusConfig(
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=4
    )

    def _run_forward(model):
        # One no-grad forward pass on the sample batch.
        with torch.no_grad():
            return model(input_ids=input_ids, attention_mask=attention_mask)

    print("\n=== Testing ESMplusplus Base Model ===")
    base_model = ESMplusplusModel(base_config).to(run_device)
    outputs = _run_forward(base_model)
    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")

    print("\nTesting embedding functionality:")
    with torch.no_grad():
        embeddings = base_model._embed(input_ids, attention_mask)
        print(f"Embedding shape: {embeddings.shape}")

    print("\n=== Testing ESMplusplus For Masked LM ===")
    mlm_model = ESMplusplusForMaskedLM(base_config).to(run_device)
    outputs = _run_forward(mlm_model)
    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    print("\n=== Testing Sequence Classification Model ===")
    classification_model = ESMplusplusForSequenceClassification(base_config).to(run_device)
    outputs = _run_forward(classification_model)
    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    print("\n=== Testing Token Classification Model ===")
    token_model = ESMplusplusForTokenClassification(base_config).to(run_device)
    outputs = _run_forward(token_model)
    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    print("\n=== Testing Embed Dataset Functionality ===")
    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
    print(f"Creating embeddings for {len(mini_dataset)} sequences")

    # Only run this if save path doesn't exist to avoid overwriting
    if os.path.exists("test_embeddings.pth"):
        print("Skipping embedding test as test_embeddings.pth already exists")
    else:
        embeddings = mlm_model.embed_dataset(
            sequences=mini_dataset,
            tokenizer=tokenizer,
            batch_size=2,
            max_len=100,
            full_embeddings=False,
            pooling_types=['mean'],
            save_path="test_embeddings.pth"
        )
        if embeddings:
            print(f"Embedding dictionary size: {len(embeddings)}")
            for seq, emb in embeddings.items():
                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
                break

    print("\nAll tests completed successfully!")
1243
+