lhallee committed on
Commit
9dfda2a
·
verified ·
1 Parent(s): de44871

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +1197 -1159
modeling_esm_plusplus.py CHANGED
@@ -1,1159 +1,1197 @@
1
- """
2
- ESM++ model implementation.
3
-
4
- ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
- The ESM Python package is not required
6
-
7
- Modified from https://github.com/evolutionaryscale/esm
8
- License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
- """
10
-
11
- import entrypoint_setup
12
- import math
13
- import os
14
- import warnings
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from dataclasses import dataclass
19
- from functools import cache, partial
20
- from pathlib import Path
21
- from typing import Optional, Tuple, Union, List
22
- from einops import rearrange, repeat
23
- from huggingface_hub import snapshot_download
24
- from tokenizers import Tokenizer
25
- from tokenizers.models import BPE
26
- from tokenizers.processors import TemplateProcessing
27
- from torch.nn.attention.flex_attention import create_block_mask
28
- from torch.nn.attention.flex_attention import flex_attention
29
- from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
- from transformers.modeling_outputs import ModelOutput
31
-
32
# Project-local helpers: the import path differs depending on whether this file
# is loaded as a package module (AutoModel / trust_remote_code) or run directly
# from the repository root. Catch only ImportError — the original bare `except:`
# would also hide unrelated failures inside embedding_mixin (and even
# KeyboardInterrupt/SystemExit).
try:
    # when used from AutoModel, these are in the same directory
    from .embedding_mixin import EmbeddingMixin, Pooler
except ImportError:
    # when running from our repo, these are in the base directory
    from embedding_mixin import EmbeddingMixin, Pooler
38
-
39
-
40
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention BlockMask that hides padding tokens.

    Args:
        attention_mask_2d: (batch, seq_len) tensor; nonzero/True entries mark
            valid (non-pad) tokens.

    Returns:
        A ``BlockMask`` from ``torch.nn.attention.flex_attention`` in which a
        query/key pair is visible only when BOTH positions are valid tokens.
    """
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Visible iff both the query and the key position are real tokens.
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    return create_block_mask(
        mask_mod,
        batch_size,
        1,  # single head group: the padding mask is identical for every head
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )
55
-
56
-
57
class ESMplusplusConfig(PretrainedConfig):
    """Configuration class for ESM++ model.

    Args:
        vocab_size: Size of the vocabulary
        hidden_size: Dimension of hidden layers
        num_attention_heads: Number of attention heads
        num_hidden_layers: Number of transformer layers
        num_labels: Number of output labels for classification
        problem_type: Type of problem - regression, single/multi label classification
        dropout: Dropout probability used by the transformer blocks
        initializer_range: Std-dev of the normal init used in _init_weights
        attn_backend: Attention implementation: "sdpa" (default) or "flex"
    """
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.dropout = dropout
        self.initializer_range = initializer_range
        # The LM head is an independent RegressionHead, never tied to the
        # input embedding table.
        self.tie_word_embeddings = False
        self.attn_backend = attn_backend
93
-
94
-
95
- ### Rotary Embeddings
96
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate half of the feature dimension of ``x`` (RoPE helper).

    Non-interleaved: split the last dim into two halves and swap them with a
    sign flip. Interleaved: treat even/odd features as pairs and rotate each
    pair.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack (-odd, even) pairwise, then merge the pair axis back into the
        # feature axis — same layout as einops "... d two -> ... (d two)".
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
106
-
107
-
108
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary embeddings to input based on cos and sin.

    Args:
        x: (batch, seqlen, nheads, headdim) activations (per the caller's
            contract in RotaryEmbedding.forward); only the leading
            ``2 * cos.shape[-1]`` feature dims are rotated, the rest pass
            through unchanged.
        cos: (>= seqlen, rotary_dim // 2) cosine table.
        sin: (>= seqlen, rotary_dim // 2) sine table.
        interleaved: Whether rotary pairs are interleaved (even/odd) instead of
            split into contiguous halves.
        _inplace: Accepted for API compatibility but ignored — the result is
            always a newly allocated tensor.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    # Tile the half-dim tables across both rotary halves and add a broadcast
    # axis for heads; "(2 d)" matches the non-interleaved rotate_half split.
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
130
-
131
-
132
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling (XPos-style); NOTE: forward() only
            supports the unscaled path — it asserts if ``scale`` is set.
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """
    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # Lazily-built cos/sin tables, grown on demand in _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the embedding."""
        inv_freq = self._compute_inv_freq(self.device)
        # Non-persistent: recomputed on load rather than stored in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        # NOTE(review): unlike inv_freq this buffer is persistent (default) —
        # presumably intentional, but verify against checkpoints.
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands: 1 / base^(2i/dim) for i in [0, dim/2)."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Update the cached cosine and sine values.

        Rebuilds the tables when the requested length exceeds the cache, the
        device/dtype changed, or a table built under inference_mode is reused
        during training (inference tensors cannot participate in autograd).
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                # Compute positions in fp32 to avoid precision loss for long
                # sequences when the model runs in half precision.
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # (seqlen, dim/2) angle matrix: position x frequency.
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos-style length scaling, centered at the sequence midpoint.
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            # The scaled (scale_base) path is intentionally unsupported here.
            assert False
265
-
266
-
267
- ### Feedforward Network Components
268
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Round ``expansion_ratio * d_model`` up to the next multiple of 256.

    Keeps the SwiGLU hidden width hardware-friendly.
    """
    padded = (expansion_ratio * d_model) + 255
    return int(padded // 256 * 256)
271
-
272
-
273
class SwiGLU(nn.Module):
    """SwiGLU activation: splits the last dim in half and gates one half
    with SiLU of the other."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
281
-
282
-
283
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Create SwiGLU feedforward network with layer normalization.

    Layout: LayerNorm -> Linear(d_model, 2*hidden) -> SwiGLU -> Linear(hidden,
    d_model), where ``hidden`` is the expansion width rounded up to a multiple
    of 256 by ``swiglu_correction_fn``. The first Linear produces 2*hidden
    features because SwiGLU splits its input in half.

    Args:
        d_model: Model (input/output) dimension.
        expansion_ratio: Hidden-width expansion ratio before rounding.
    """
    # Compute the corrected hidden width once (the original called
    # swiglu_correction_fn twice with identical arguments).
    hidden_dim = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim * 2, bias=False),
        SwiGLU(),
        nn.Linear(hidden_dim, d_model, bias=False),
    )
293
-
294
-
295
- ### Attention
296
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        attn_backend: "sdpa" (default) or "flex"; flex is used only when no
            mask is needed or a prebuilt flex block mask is supplied, and
            falls back to SDPA on failure.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.attn_backend = attn_backend
        # Emit the flex->SDPA fallback warning at most once per module.
        self._warned_flex_fallback = False
        # Fused pre-norm + QKV projection (no biases, per ESMC).
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # QK-norm: separate LayerNorms on queries and keys before rotary.
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        # (batch, seq, heads*dim) -> (batch, heads, seq, dim)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key.

        Temporarily unflattens the feature dim into (heads, head_dim) because
        RotaryEmbedding expects (batch, seq, heads, head_dim).
        """
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor (batch, seq, d_model)
            attention_mask: Optional boolean mask, broadcastable to
                (batch, heads, seq, seq); False entries are masked out
            flex_block_mask: Optional prebuilt flex-attention BlockMask
            output_attentions: Whether to return attention weights (forces the
                manual/eager computation path)

        Returns:
            Output tensor after self attention, and attention weights
            (None unless output_attentions)
        """
        attn_weights = None
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        # QK-norm, cast back since LayerNorm may upcast under autocast.
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        scale = 1 / math.sqrt(self.d_head)

        if output_attentions: # Manual attention computation
            b, h, l, _ = query_BHLD.shape
            # Additive bias: 0 for visible pairs, -inf for masked ones.
            attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
            attn_weights += attn_bias
            attn_weights = F.softmax(attn_weights, dim=-1)
            context_BHLD = torch.matmul(attn_weights, value_BHLD)
        else:
            # Convert the boolean mask into SDPA's additive float mask.
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # Flex is only usable when it has a mask it understands (block
            # mask) or no mask at all.
            use_flex = (
                self.attn_backend == "flex"
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    context_BHLD = flex_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        block_mask=flex_block_mask,
                        scale=scale,
                    )
                except Exception as exc:
                    # Best-effort fallback: flex can fail on unsupported
                    # shapes/devices; degrade to SDPA and warn once.
                    if not self._warned_flex_fallback:
                        warnings.warn(
                            f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
                            RuntimeWarning,
                        )
                        self._warned_flex_fallback = True
                    context_BHLD = F.scaled_dot_product_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        attn_mask=sdpa_mask,
                        scale=scale,
                    )
            else:
                context_BHLD = F.scaled_dot_product_attention(
                    query_BHLD,
                    key_BHLD,
                    value_BHLD,
                    attn_mask=sdpa_mask,
                    scale=scale,
                )

        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        output = self.out_proj(context_BLD)
        return output, attn_weights
413
-
414
-
415
- ### Regression Head
416
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Build a small MLP head: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        d_model: Input dimension
        output_dim: Output dimension
        hidden_dim: Optional hidden dimension (defaults to d_model)
    """
    if hidden_dim is None:
        hidden_dim = d_model
    layers = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    ]
    return nn.Sequential(*layers)
431
-
432
-
433
- ### Transformer Block
434
class UnifiedTransformerBlock(nn.Module):
    """Transformer block with attention and feedforward layers.

    Pre-norm style: both sub-layers normalize internally, and each residual
    branch is passed through dropout and divided by ``residue_scaling_factor``.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        residue_scaling_factor: Factor for scaling residual connections
        expansion_ratio: Expansion ratio for feedforward network
        dropout: Dropout probability on both residual branches
        attn_backend: Attention implementation forwarded to MultiHeadAttention
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            attn_backend=attn_backend,
        )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor (batch, seq, d_model)
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention BlockMask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after transformer block, and attention weights
            (None unless output_attentions)
        """
        attn_output, attn_weights = self.attn(
            x,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        # Scaled residual connections (scaling set to sqrt(n_layers/36) by
        # TransformerStack).
        x = x + self.dropout(attn_output) / self.scaling_factor
        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        return x, attn_weights
487
-
488
-
489
- ### Model Outputs
490
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder."""
    last_hidden_state: Optional[torch.Tensor] = None  # final-layer activations after the stack's LayerNorm
    hidden_states: Optional[Tuple[torch.Tensor]] = None  # per-layer activations, only when requested
    attentions: Optional[Tuple[torch.Tensor]] = None  # per-layer attention maps, only when requested
496
-
497
-
498
@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models."""
    loss: Optional[torch.Tensor] = None  # present only when labels were supplied
    logits: Optional[torch.Tensor] = None  # head output (MLM / classification)
    last_hidden_state: Optional[torch.Tensor] = None  # final transformer activations
    hidden_states: Optional[Tuple[torch.Tensor]] = None  # per-layer activations, only when requested
    attentions: Optional[Tuple[torch.Tensor]] = None  # per-layer attention maps, only when requested
506
-
507
-
508
- ### Transformer Stack
509
class TransformerStack(nn.Module):
    """Stack of transformer blocks.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dropout: Dropout rate
        attn_backend: "sdpa" or "flex"; flex builds a padding BlockMask once
            per forward and shares it across layers
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn_backend = attn_backend
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    # ESMC convention: residuals scaled by sqrt(n_layers / 36).
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                    attn_backend=attn_backend,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        # Toggled by HF's gradient_checkpointing_enable(), which also installs
        # self._gradient_checkpointing_func — presumably; confirm against the
        # transformers version in use.
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TransformerOutput:
        """
        Args:
            x: Input tensor (batch, seq, d_model)
            attention_mask: Optional (batch, seq) token mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        batch_size, seq_len, _ = x.shape
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        if attention_mask is not None:
            # Broadcast the key-side token mask to (batch, 1, seq, seq) for
            # the attention layers.
            attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
            if self.attn_backend == "flex" and not output_attentions:
                # Recover the original per-token mask row to build the flex
                # BlockMask once, shared by every layer.
                token_attention_mask = attention_mask[:, 0, 0, :]
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
            else:
                flex_block_mask = None
        else:
            flex_block_mask = None

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x, attn_weights = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)

            if attentions is not None:
                attentions += (attn_weights,)

            if output_hidden_states:
                assert hidden_states is not None
                hidden_states += (x,)

        return TransformerOutput(
            last_hidden_state=self.norm(x),
            hidden_states=hidden_states,
            attentions=attentions
        )
598
-
599
-
600
class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    init weights for ESM++ models
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True
    all_tied_weights_keys = {}  # no tied weights (config sets tie_word_embeddings=False)

    def _init_weights(self, module):
        """Initialize the weights: normal(0, initializer_range) for Linear and
        Embedding weights, zeros for biases/padding rows, ones for LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Several LayerNorms in this model are created with bias=False.
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model from official ESMC weights.

        NOTE: dispatch is by substring — any name containing "300" or "600"
        matches; the builders are module-level functions defined later in
        this file.
        """
        if '300' in model_name:
            return ESMplusplus_300M()
        elif '600' in model_name:
            return ESMplusplus_600M()
        else:
            raise ValueError(f"Invalid model name: {model_name}")
633
-
634
-
635
- ### ESM++ Models
636
class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model. transformer model with no heads
    """
    config_class = ESMplusplusConfig
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        # EsmSequenceTokenizer is defined elsewhere in this file (not visible
        # in this chunk).
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: embed + run the stack, return only the final
        # hidden state.
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> TransformerOutput:
        """Forward pass of the headless encoder.

        Args:
            input_ids: Input token IDs (ignored when inputs_embeds is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights
            return_dict: Accepted for compatibility but ignored — a
                TransformerOutput is always returned

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
693
-
694
-
695
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model for masked language modeling.
    Implements the base ESM++ architecture with a masked language modeling head.
    """
    config_class = ESMplusplusConfig
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        # MLM head: independent RegressionHead, not tied to self.embed.
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        # Default CrossEntropyLoss, so positions labeled -100 are ignored.
        self.ce_loss = nn.CrossEntropyLoss()
        # EsmSequenceTokenizer is defined elsewhere in this file (not visible
        # in this chunk).
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # The final Linear of the MLM head projects to vocab_size.
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: embed + run the stack, return only the final
        # hidden state.
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for masked language modeling.

        Args:
            input_ids: Input token IDs (ignored when inputs_embeds is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for masked tokens
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights
            return_dict: Accepted for compatibility but ignored

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
        x = output.last_hidden_state
        logits = self.sequence_head(x)
        loss = None
        if labels is not None:
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
775
-
776
-
777
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        # Input is hidden_size * 2 — presumably because the default Pooler
        # concatenates one vector per pooling type (cls + mean); custom
        # pooling_types of a different length would need a matching head.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        # Use caller-supplied pooling_types when given as a non-empty list of
        # strings; otherwise default to ['cls', 'mean'].
        # BUGFIX: the previous check was isinstance(kwargs['pooling_types'],
        # List[str]), which raises TypeError at runtime — subscripted generics
        # cannot be used with isinstance() — so passing pooling_types always
        # crashed. The intent (non-empty list of strings) is checked here
        # with valid runtime checks.
        pooling_types = kwargs.get('pooling_types')
        if not (
            isinstance(pooling_types, list)
            and len(pooling_types) > 0
            and all(isinstance(p, str) for p in pooling_types)
        ):
            pooling_types = ['cls', 'mean']
        self.pooler = Pooler(pooling_types)
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: embed + run the stack, return only the final
        # hidden state.
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights
            return_dict: Accepted for compatibility but ignored

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # labels=None here: the MLM loss is never computed for classification.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Standard HF problem_type inference: regression for a single
            # label, single-label classification for integer targets,
            # multi-label otherwise. The inferred type is cached on config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
866
-
867
-
868
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for token classification.
    Extends the base ESM++ model with a token classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Per-token head: no pooling, so input width is hidden_size.
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        # Default CrossEntropyLoss, so positions labeled -100 are ignored.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: embed + run the stack, return only the final
        # hidden state.
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for token classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights
            return_dict: Accepted for compatibility but ignored

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        # labels=None here: the MLM loss is never computed for classification.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
930
-
931
-
932
- ### Loading from EvolutionaryScale
933
@cache
def data_root(model: str):
    """Resolve the local directory that holds EvolutionaryScale ESMC weights.

    Note: the original code also wrapped this in ``@staticmethod`` even though it
    is a module-level function; on Python < 3.10 a bare staticmethod object is
    not callable, so that decorator is removed here.

    Args:
        model: Model name; must start with "esmc-300" or "esmc-600".

    Returns:
        Path to the snapshot root (empty Path when INFRA_PROVIDER is set and
        weights are expected to be provided externally).

    Raises:
        ValueError: If the model name matches neither known family.
    """
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    # Try to download from huggingface if it doesn't exist (cached by @cache per model name)
    if model.startswith("esmc-300"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
    elif model.startswith("esmc-600"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
    else:
        raise ValueError(f"{model=} is an invalid model name.")
    return path
946
-
947
-
948
def ESMplusplus_300M(device: torch.device | str = "cpu"):
    """Build the 300M-parameter ESM++ model and load the official ESMC-300M weights.

    Args:
        device: Device on which to materialize the model parameters.

    Returns:
        An ``ESMplusplusForMaskedLM`` with the pretrained state dict loaded.
    """
    config = ESMplusplusConfig(
        hidden_size=960,
        num_attention_heads=15,
        num_hidden_layers=30,
    )
    with torch.device(device):
        model = ESMplusplusForMaskedLM(config)
    checkpoint_path = data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth"
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model
962
-
963
-
964
def ESMplusplus_600M(device: torch.device | str = "cpu"):
    """Build the 600M-parameter ESM++ model and load the official ESMC-600M weights.

    Args:
        device: Device on which to materialize the model parameters.

    Returns:
        An ``ESMplusplusForMaskedLM`` with the pretrained state dict loaded.
    """
    config = ESMplusplusConfig(
        hidden_size=1152,
        num_attention_heads=18,
        num_hidden_layers=36,
    )
    with torch.device(device):
        model = ESMplusplusForMaskedLM(config)
    checkpoint_path = data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth"
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model
978
-
979
-
980
### Tokenization
# Character-level ESMC vocabulary. The list index doubles as the token id
# (EsmSequenceTokenizer builds token_to_id via enumerate), so the ordering
# below must stay stable.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]
988
-
989
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level tokenizer for protein sequences using the ESMC vocabulary.

    Implemented as a BPE model with an empty merge list (one token per
    character). Sequences are wrapped as ``<cls> $A <eos>`` by the
    post-processor when ``add_special_tokens=True``.
    """
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        # "|" by default; separates chains in multi-chain inputs
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
1066
-
1067
-
1068
if __name__ == "__main__":
    # Smoke tests: run on GPU when available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Test tokenizer on a sample protein sequence
    tokenizer = EsmSequenceTokenizer()
    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
    encoding = tokenizer(sample_sequence, return_tensors="pt")
    print(f"Input sequence length: {len(sample_sequence)}")
    print(f"Tokenized sequence: {encoding['input_ids'].shape}")

    # Prepare inputs
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Test base model with smaller config for quick testing
    print("\n=== Testing ESMplusplus Base Model ===")
    base_config = ESMplusplusConfig(
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=4
    )
    base_model = ESMplusplusModel(base_config).to(device)

    with torch.no_grad():
        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")

    # Test embedding functionality (EmbeddingMixin._embed hook)
    print("\nTesting embedding functionality:")
    with torch.no_grad():
        embeddings = base_model._embed(input_ids, attention_mask)
        print(f"Embedding shape: {embeddings.shape}")

    # Test masked language modeling
    print("\n=== Testing ESMplusplus For Masked LM ===")
    mlm_model = ESMplusplusForMaskedLM(base_config).to(device)

    with torch.no_grad():
        outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test sequence classification model
    print("\n=== Testing Sequence Classification Model ===")
    classification_model = ESMplusplusForSequenceClassification(base_config).to(device)

    with torch.no_grad():
        outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test token classification model
    print("\n=== Testing Token Classification Model ===")
    token_model = ESMplusplusForTokenClassification(base_config).to(device)

    with torch.no_grad():
        outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test embedding dataset functionality with a mini dataset
    print("\n=== Testing Embed Dataset Functionality ===")
    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
    print(f"Creating embeddings for {len(mini_dataset)} sequences")

    # Only run this if save path doesn't exist to avoid overwriting
    if not os.path.exists("test_embeddings.pth"):
        embeddings = mlm_model.embed_dataset(
            sequences=mini_dataset,
            tokenizer=tokenizer,
            batch_size=2,
            max_len=100,
            full_embeddings=False,
            pooling_types=['mean'],
            save_path="test_embeddings.pth"
        )
        if embeddings:
            print(f"Embedding dictionary size: {len(embeddings)}")
            for seq, emb in embeddings.items():
                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
                break
    else:
        print("Skipping embedding test as test_embeddings.pth already exists")

    print("\nAll tests completed successfully!")
1159
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ESM++ model implementation.
3
+
4
+ ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
+ The ESM Python package is not required
6
+
7
+ Modified from https://github.com/evolutionaryscale/esm
8
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
+ """
10
+
11
+ import entrypoint_setup
12
+ import math
13
+ import os
14
+ import warnings
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from dataclasses import dataclass
19
+ from functools import cache, partial
20
+ from pathlib import Path
21
+ from typing import Optional, Tuple, Union, List
22
+ from einops import rearrange, repeat
23
+ from huggingface_hub import snapshot_download
24
+ from tokenizers import Tokenizer
25
+ from tokenizers.models import BPE
26
+ from tokenizers.processors import TemplateProcessing
27
+ from torch.nn.attention.flex_attention import create_block_mask
28
+ from torch.nn.attention.flex_attention import flex_attention
29
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
+ from transformers.modeling_outputs import ModelOutput
31
+
32
try:
    # when used from AutoModel (trust_remote_code), these are in the same package
    from .embedding_mixin import EmbeddingMixin, Pooler
except ImportError:
    # when running from our repo, these are in the base directory.
    # Catch only ImportError: a bare `except:` would also swallow unrelated
    # failures raised while importing embedding_mixin itself (and even
    # KeyboardInterrupt/SystemExit).
    from embedding_mixin import EmbeddingMixin, Pooler
38
+
39
+
40
+ def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
41
+ token_valid = attention_mask_2d.bool()
42
+ batch_size, seq_len = token_valid.shape
43
+
44
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
45
+ return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
46
+
47
+ return create_block_mask(
48
+ mask_mod,
49
+ batch_size,
50
+ 1,
51
+ seq_len,
52
+ seq_len,
53
+ device=attention_mask_2d.device,
54
+ )
55
+
56
+
57
class ESMplusplusConfig(PretrainedConfig):
    """Configuration class for ESM++ model.

    Args:
        vocab_size: Size of the vocabulary
        hidden_size: Dimension of hidden layers
        num_attention_heads: Number of attention heads
        num_hidden_layers: Number of transformer layers
        num_labels: Number of output labels for classification
        problem_type: Type of problem - regression, single/multi label classification
        dropout: Dropout probability applied inside the transformer blocks
        initializer_range: Std of the truncated-normal weight initialization
        attn_backend: Attention implementation, "sdpa" (default) or "flex"
    """
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.dropout = dropout
        self.initializer_range = initializer_range
        # Input embedding and LM head weights are independent in ESMC checkpoints.
        self.tie_word_embeddings = False
        self.attn_backend = attn_backend
93
+
94
+
95
+ ### Rotary Embeddings
96
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate half the hidden dims: maps (x1, x2) to (-x2, x1).

    Args:
        x: Input tensor; rotation acts on the last dimension.
        interleaved: If True, pairs are (even, odd) adjacent elements instead
            of the two contiguous halves.

    Returns:
        Tensor of the same shape with the rotated last dimension.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # stack -> (..., d, 2), flatten -> interleave (-odd_i, even_i) pairs
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
106
+
107
+
108
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to the leading ``2 * cos.shape[-1]``
    channels of ``x``; remaining channels pass through unchanged.

    Args:
        x: (batch, seqlen, nheads, headdim) input.
        cos: (>=seqlen, rotary_dim / 2) cosine table.
        sin: (>=seqlen, rotary_dim / 2) sine table.
        interleaved: Pairing convention forwarded to ``rotate_half``.
        _inplace: Accepted for API compatibility; this implementation is
            always out-of-place.

    Returns:
        Tensor of the same shape as ``x``.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    # Duplicate the half-dim tables along the channel axis and add a
    # broadcastable head axis: (s, d) -> (s, 1, 2d).
    cos_full = torch.cat((cos[:seqlen], cos[:seqlen]), dim=-1).unsqueeze(1)
    sin_full = torch.cat((sin[:seqlen], sin[:seqlen]), dim=-1).unsqueeze(1)
    rotated = x[..., :ro_dim] * cos_full + rotate_half(x[..., :ro_dim], interleaved) * sin_full
    return torch.cat((rotated, x[..., ro_dim:]), dim=-1)
130
+
131
+
132
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling (xPos-style); the scaled path is not
            supported in forward() and will assert
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """
    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # cos/sin caches are lazily (re)built per (seqlen, device, dtype)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the embedding."""
        inv_freq = self._compute_inv_freq(self.device)
        # non-persistent: recomputed on load rather than stored in checkpoints
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Update the cached cosine and sine values.

        Rebuilds when the requested seqlen grows, the device/dtype changed,
        or a cache built under inference_mode is reused in training.
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                # fp32 positions avoid precision loss for long sequences in low precision
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # xPos-style scaled tables (separate q/k variants)
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            # scaled (xPos) path is not implemented
            assert False
265
+
266
+
267
+ ### Feedforward Network Components
268
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Round ``expansion_ratio * d_model`` up to the nearest multiple of 256.

    Keeps the SwiGLU hidden width hardware-friendly while matching the
    requested expansion ratio as closely as possible.
    """
    raw_width = expansion_ratio * d_model
    return int((raw_width + 255) // 256) * 256
271
+
272
+
273
class SwiGLU(nn.Module):
    """Gated SiLU activation: split the last dim in half, return silu(gate) * value."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value
281
+
282
+
283
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Build the ESM++ feedforward block: LayerNorm -> fused gate/value
    projection -> SwiGLU -> output projection (all linears bias-free)."""
    inner_dim = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        # single matmul producing both the gate and value halves
        nn.Linear(d_model, inner_dim * 2, bias=False),
        SwiGLU(),
        nn.Linear(inner_dim, d_model, bias=False),
    )
293
+
294
+
295
+ ### Attention
296
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Suffix convention in variable names: B=batch, L=sequence length,
    H=heads, D=model dim (3 = stacked q/k/v).

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        attn_backend: "sdpa" (default) or "flex"; flex falls back to SDPA on
            failure with a one-time warning
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.attn_backend = attn_backend
        # guards the flex->SDPA fallback warning so it fires at most once
        self._warned_flex_fallback = False
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # QK-LayerNorm (applied before rotary) stabilizes attention logits
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key."""
        # (b, s, h*d) -> (b, s, h, d) for per-head rotation, then back
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional (boolean) pairwise attention mask
            flex_block_mask: Optional precomputed BlockMask for the flex backend
            output_attentions: Whether to return attention weights (forces the
                manual softmax path)

        Returns:
            Output tensor after self attention, and optionally attention weights
        """
        attn_weights = None
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        scale = 1 / math.sqrt(self.d_head)

        if output_attentions:  # Manual attention computation
            b, h, l, _ = query_BHLD.shape
            # additive bias: -inf where the boolean mask is False
            attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
            attn_weights += attn_bias
            attn_weights = F.softmax(attn_weights, dim=-1)
            context_BHLD = torch.matmul(attn_weights, value_BHLD)
        else:
            # convert boolean mask to an additive float mask for SDPA
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # flex requires a BlockMask whenever a mask is in play
            use_flex = (
                self.attn_backend == "flex"
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    context_BHLD = flex_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        block_mask=flex_block_mask,
                        scale=scale,
                    )
                except Exception as exc:
                    if not self._warned_flex_fallback:
                        warnings.warn(
                            f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
                            RuntimeWarning,
                        )
                        self._warned_flex_fallback = True
                    context_BHLD = F.scaled_dot_product_attention(
                        query_BHLD,
                        key_BHLD,
                        value_BHLD,
                        attn_mask=sdpa_mask,
                        scale=scale,
                    )
            else:
                context_BHLD = F.scaled_dot_product_attention(
                    query_BHLD,
                    key_BHLD,
                    value_BHLD,
                    attn_mask=sdpa_mask,
                    scale=scale,
                )

        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        output = self.out_proj(context_BLD)
        return output, attn_weights
413
+
414
+
415
+ ### Regression Head
416
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Create a regression head with optional hidden dimension.

    Args:
        d_model: Input dimension
        output_dim: Output dimension
        hidden_dim: Optional hidden dimension (defaults to d_model)

    Returns:
        nn.Sequential: Linear -> GELU -> LayerNorm -> Linear.
    """
    if hidden_dim is None:
        hidden_dim = d_model
    return nn.Sequential(
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    )
431
+
432
+
433
+ ### Transformer Block
434
class UnifiedTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then SwiGLU feedforward,
    each wrapped in a scaled residual connection.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        residue_scaling_factor: Divisor applied to each residual branch
        expansion_ratio: Expansion ratio for the feedforward network
        dropout: Dropout applied to both branch outputs
        attn_backend: Attention implementation forwarded to MultiHeadAttention
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            attn_backend=attn_backend,
        )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional BlockMask for the flex attention backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after the block, and the attention weights
            (None unless output_attentions=True)
        """
        attn_out, attn_weights = self.attn(
            x,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        x = x + self.dropout(attn_out) / self.scaling_factor
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out) / self.scaling_factor
        return x, attn_weights
487
+
488
+
489
+ ### Model Outputs
490
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder."""
    # Hidden states after the stack's final LayerNorm: (batch, seq, hidden)
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states, populated when output_hidden_states=True
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # Per-layer attention weights, populated when output_attentions=True
    attentions: Optional[Tuple[torch.Tensor]] = None
496
+
497
+
498
@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models."""
    # Task loss (MLM / classification), None when no labels were given
    loss: Optional[torch.Tensor] = None
    # Head outputs (vocabulary or label logits)
    logits: Optional[torch.Tensor] = None
    # Final encoder hidden states: (batch, seq, hidden)
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states, populated when output_hidden_states=True
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # Per-layer attention weights, populated when output_attentions=True
    attentions: Optional[Tuple[torch.Tensor]] = None
506
+
507
+
508
+ ### Transformer Stack
509
class TransformerStack(nn.Module):
    """Stack of transformer blocks.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dropout: Dropout rate
        attn_backend: Attention implementation forwarded to every block
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
        attn_backend: str = "sdpa",
    ):
        super().__init__()
        self.attn_backend = attn_backend
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    # ESMC convention: residual scaling grows with depth,
                    # normalized to the 36-layer reference model
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                    attn_backend=attn_backend,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        # toggled by HF's gradient_checkpointing_enable()
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TransformerOutput:
        """
        Args:
            x: Input tensor
            attention_mask: Optional 2D (batch, seq) token validity mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        if attention_mask is not None:
            assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
            token_attention_mask = attention_mask.bool()
            # Expand the per-token mask to a pairwise (q, kv) boolean mask with
            # a broadcastable head dimension: (b, s) -> (b, 1, s, s)
            pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
            attention_mask = pairwise_attention_mask.unsqueeze(1)
            # Build the flex BlockMask once and share it across all blocks;
            # output_attentions forces the manual path, which doesn't use it.
            if self.attn_backend == "flex" and not output_attentions:
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
            else:
                flex_block_mask = None
        else:
            flex_block_mask = None

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x, attn_weights = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)

            if attentions is not None:
                attentions += (attn_weights,)

            if output_hidden_states:
                assert hidden_states is not None
                hidden_states += (x,)

        return TransformerOutput(
            last_hidden_state=self.norm(x),
            hidden_states=hidden_states,
            attentions=attentions
        )
599
+
600
+
601
class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    Shared PreTrainedModel base for all ESM++ variants: weight initialization
    and pretrained-checkpoint helpers.
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize weights: normal(0, initializer_range) for linear/embedding,
        ones/zeros for LayerNorm."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # keep the padding embedding exactly zero
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model by ESMC-style name (containing '300' or '600')."""
        if '300' in model_name:
            return ESMplusplus_300M()
        if '600' in model_name:
            return ESMplusplus_600M()
        raise ValueError(f"Invalid model name: {model_name}")
634
+
635
+
636
+ ### ESM++ Models
637
class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model. transformer model with no heads
    """
    config_class = ESMplusplusConfig
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        # bundled tokenizer so callers don't need a separate download
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: single forward pass returning only final hidden states
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> TransformerOutput:
        """Forward pass of the bare encoder (no task head).

        Args:
            input_ids: Input token IDs (ignored when inputs_embeds is given)
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
694
+
695
+
696
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model for masked language modeling.
    Couples the ESM++ encoder with a per-token head projecting onto the vocabulary.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    # ---- HF input/output embedding accessors ----
    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # Final projection layer of the LM head.
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode ``input_ids`` and return only the final hidden states."""
        token_states = self.embed(input_ids)
        encoded = self.transformer(
            token_states, attention_mask, output_hidden_states=False, output_attentions=False
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for masked language modeling.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask.
            inputs_embeds: Optional precomputed embeddings (bypass the embedding table).
            labels: Optional masked-token labels; cross-entropy over the vocabulary.
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and attention weights.
        """
        hidden = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        encoded = self.transformer(hidden, attention_mask, output_hidden_states, output_attentions)
        final_states = encoded.last_hidden_state
        logits = self.sequence_head(final_states)
        loss = (
            self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
            if labels is not None
            else None
        )
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_states,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
776
+
777
+
778
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a pooled classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Use caller-provided pooling_types when valid, otherwise ['cls', 'mean'].
        # FIX: the previous check used isinstance(..., List[str]), which raises
        # TypeError at runtime — subscripted generics cannot be used with isinstance().
        pooling_types = kwargs.get('pooling_types')
        if not (isinstance(pooling_types, list) and len(pooling_types) > 0):
            pooling_types = ['cls', 'mean']
        self.pooler = Pooler(pooling_types)
        # Scale the classifier input width with the number of pooling types so
        # custom pooling stays shape-compatible (default 2 types reproduces the
        # original hidden_size * 2). Assumes Pooler concatenates one
        # hidden_size-wide vector per pooling type — confirm against Pooler.
        # Large intermediate projections help with sequence classification tasks (*4)
        self.classifier = RegressionHead(config.hidden_size * len(pooling_types), config.num_labels, config.hidden_size * 4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode ``input_ids`` and return only the final hidden states."""
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask.
            inputs_embeds: Optional precomputed embeddings.
            labels: Optional labels; the loss used depends on config.problem_type.
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states.
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer problem_type once, mirroring HF sequence-classification heads.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
867
+
868
+
869
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for token classification.
    Adds a per-residue classification head on top of the ESM++ encoder.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projections help with sequence classification tasks (*4)
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode ``input_ids`` and return only the final hidden states."""
        token_states = self.embed(input_ids)
        encoded = self.transformer(
            token_states, attention_mask, output_hidden_states=False, output_attentions=False
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask.
            inputs_embeds: Optional precomputed embeddings.
            labels: Optional per-token labels (cross-entropy over num_labels).
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states.
        """
        encoded = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        final_states = encoded.last_hidden_state
        logits = self.classifier(final_states)
        loss = (
            self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            if labels is not None
            else None
        )
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_states,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
931
+
932
+
933
+ ### Loading from EvolutionaryScale
934
+ _ESMC_CHECKPOINT_SPECS = {
935
+ "esmc-300": {
936
+ "repo_id": "EvolutionaryScale/esmc-300m-2024-12",
937
+ "weights_relpath": "data/weights/esmc_300m_2024_12_v0.pth",
938
+ "hidden_size": 960,
939
+ "num_attention_heads": 15,
940
+ "num_hidden_layers": 30,
941
+ },
942
+ "esmc-600": {
943
+ "repo_id": "EvolutionaryScale/esmc-600m-2024-12",
944
+ "weights_relpath": "data/weights/esmc_600m_2024_12_v0.pth",
945
+ "hidden_size": 1152,
946
+ "num_attention_heads": 18,
947
+ "num_hidden_layers": 36,
948
+ },
949
+ }
950
+
951
+
952
+ def _resolve_esmc_checkpoint_key(model: str) -> str:
953
+ if "esmc-300" in model:
954
+ return "esmc-300"
955
+ if "esmc-600" in model:
956
+ return "esmc-600"
957
+ raise ValueError(f"{model=} is an invalid ESMC model name.")
958
+
959
+
960
@cache
def data_root(model: str):
    """Return the local directory that holds the ESMC checkpoint data.

    Cached per model string. When the ``INFRA_PROVIDER`` environment variable
    is set, weights are assumed to be provisioned externally and an empty
    relative path is returned instead of downloading from the Hub.

    FIX: removed the stray ``@staticmethod`` decorator — this is a
    module-level function, and staticmethod objects are not directly callable
    before Python 3.10, which broke every ``data_root(...)`` call.
    """
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    key = _resolve_esmc_checkpoint_key(model)
    return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"]))
967
+
968
+
969
def get_esmc_checkpoint_path(model: str) -> Path:
    """Resolve the on-disk path of the ``.pth`` weights file for an ESMC model."""
    key = _resolve_esmc_checkpoint_key(model)
    spec = _ESMC_CHECKPOINT_SPECS[key]
    return data_root(key) / spec["weights_relpath"]
972
+
973
+
974
def _load_esmc_checkpoint_model(
    config: ESMplusplusConfig,
    model: str,
    device: torch.device | str = "cpu",
) -> ESMplusplusForMaskedLM:
    """Instantiate an ESM++ MLM model and load the official ESMC weights.

    Args:
        config: Architecture config; must match the requested checkpoint.
        model: Free-form ESMC model name (resolved by substring match).
        device: Device to construct the model on and map the weights to.

    Returns:
        The model with the checkpoint's state dict loaded.

    Raises:
        ValueError: if ``model`` is unknown or ``config`` does not match the
            checkpoint's architecture. (Previously these were ``assert``
            statements, which are stripped under ``python -O``.)
    """
    key = _resolve_esmc_checkpoint_key(model)
    spec = _ESMC_CHECKPOINT_SPECS[key]
    for field in ("hidden_size", "num_attention_heads", "num_hidden_layers"):
        expected = spec[field]
        actual = getattr(config, field)
        if actual != expected:
            raise ValueError(
                f"ESMC loader expected {field}={expected} for {key}, "
                f"but got {actual}."
            )
    with torch.device(device):
        model_obj = ESMplusplusForMaskedLM(config)
        state_dict = torch.load(get_esmc_checkpoint_path(key), map_location=device)
        model_obj.load_state_dict(state_dict)
        return model_obj
998
+
999
+
1000
def ESMplusplus_300M(device: torch.device | str = "cpu"):
    """Build ESM++ 300M and load the official ESMC-300M weights."""
    cfg = ESMplusplusConfig(
        hidden_size=960,
        num_attention_heads=15,
        num_hidden_layers=30,
    )
    return _load_esmc_checkpoint_model(config=cfg, model="esmc-300", device=device)
1007
+
1008
+
1009
def ESMplusplus_600M(device: torch.device | str = "cpu"):
    """Build ESM++ 600M and load the official ESMC-600M weights."""
    cfg = ESMplusplusConfig(
        hidden_size=1152,
        num_attention_heads=18,
        num_hidden_layers=36,
    )
    return _load_esmc_checkpoint_model(config=cfg, model="esmc-600", device=device)
1016
+
1017
+
1018
### Tokenization
# 4 leading special tokens, 28 residue/separator characters, then <mask>.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    *"LAGVSERTIDPKQNFYMHWCXBUZO.-|",
    "<mask>",
]
1026
+
1027
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level amino-acid tokenizer for ESM++.

    Implemented as a BPE model with an empty merge table (equivalent to a
    character tokenizer), wrapped in the HF fast-tokenizer interface.
    """
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        # Tokens the backend must never split; <unk> is already handled by BPE.
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        # Alias BOS to CLS so downstream code that asks for BOS gets a real token.
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        # "|" — separates chains within a multi-chain input.
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        # Every id in the vocabulary, in order.
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
1104
+
1105
+
1106
if __name__ == "__main__":
    # Smoke-test script: exercises the tokenizer and every model head with a
    # small config. Runs on GPU when available, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Test tokenizer
    tokenizer = EsmSequenceTokenizer()
    sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
    encoding = tokenizer(sample_sequence, return_tensors="pt")
    print(f"Input sequence length: {len(sample_sequence)}")
    print(f"Tokenized sequence: {encoding['input_ids'].shape}")

    # Prepare inputs
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Test base model with smaller config for quick testing
    print("\n=== Testing ESMplusplus Base Model ===")
    base_config = ESMplusplusConfig(
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=4
    )
    base_model = ESMplusplusModel(base_config).to(device)

    with torch.no_grad():
        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")

    # Test embedding functionality
    print("\nTesting embedding functionality:")
    with torch.no_grad():
        embeddings = base_model._embed(input_ids, attention_mask)
        print(f"Embedding shape: {embeddings.shape}")

    # Test masked language modeling
    print("\n=== Testing ESMplusplus For Masked LM ===")
    mlm_model = ESMplusplusForMaskedLM(base_config).to(device)

    with torch.no_grad():
        outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test sequence classification model
    print("\n=== Testing Sequence Classification Model ===")
    classification_model = ESMplusplusForSequenceClassification(base_config).to(device)

    with torch.no_grad():
        outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test token classification model
    print("\n=== Testing Token Classification Model ===")
    token_model = ESMplusplusForTokenClassification(base_config).to(device)

    with torch.no_grad():
        outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)

    print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Test embedding dataset functionality with a mini dataset
    print("\n=== Testing Embed Dataset Functionality ===")
    mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
    print(f"Creating embeddings for {len(mini_dataset)} sequences")

    # Only run this if save path doesn't exist to avoid overwriting
    if not os.path.exists("test_embeddings.pth"):
        embeddings = mlm_model.embed_dataset(
            sequences=mini_dataset,
            tokenizer=tokenizer,
            batch_size=2,
            max_len=100,
            full_embeddings=False,
            pooling_types=['mean'],
            save_path="test_embeddings.pth"
        )
        if embeddings:
            print(f"Embedding dictionary size: {len(embeddings)}")
            # Print only the first entry as a sanity check.
            for seq, emb in embeddings.items():
                print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
                break
    else:
        print("Skipping embedding test as test_embeddings.pth already exists")

    print("\nAll tests completed successfully!")
1197
+