lhallee committed on
Commit
d1e8c88
·
verified ·
1 Parent(s): 34ca892

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +1038 -1032
modeling_fastesm.py CHANGED
@@ -1,1032 +1,1038 @@
1
- import entrypoint_setup
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from typing import Optional, Tuple, Union, Dict, Any
6
- from einops import rearrange
7
- from dataclasses import dataclass
8
- from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
- from transformers.modeling_outputs import (
10
- ModelOutput,
11
- BaseModelOutputWithPastAndCrossAttentions,
12
- BaseModelOutputWithPoolingAndCrossAttentions,
13
- SequenceClassifierOutput,
14
- TokenClassifierOutput
15
- )
16
- from transformers.models.esm.modeling_esm import (
17
- EsmIntermediate,
18
- EsmOutput,
19
- EsmPooler,
20
- EsmLMHead,
21
- EsmSelfOutput,
22
- EsmClassificationHead,
23
- )
24
- try:
25
- from torch.nn.attention.flex_attention import create_block_mask
26
- from torch.nn.attention.flex_attention import flex_attention
27
- except ImportError:
28
- create_block_mask = None
29
- flex_attention = None
30
-
31
- try:
32
- # when used from AutoModel, these are in the same directory
33
- from .embedding_mixin import EmbeddingMixin, Pooler
34
- except:
35
- # when running from our repo, these are in the base directory
36
- from embedding_mixin import EmbeddingMixin, Pooler
37
-
38
-
39
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention block mask that hides padding tokens.

    Args:
        attention_mask_2d: ``(batch, seq_len)`` mask where nonzero entries mark
            valid (non-pad) tokens.

    Returns:
        A block mask suitable for the ``block_mask=`` argument of
        ``flex_attention``.
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    # A (query, key) pair may attend only when BOTH positions are real tokens.
    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    # H=1: the same padding mask is broadcast across all attention heads.
    return create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )
55
-
56
-
57
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output for FastEsmForMaskedLM.

    Same fields as the standard HF masked-LM output, plus ``last_hidden_state``
    so the final encoder representation is returned alongside the logits.
    """
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
64
-
65
-
66
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastESM models.

    Mirrors the standard ESM-2 configuration, with an additional
    ``attn_backend`` switch ("sdpa" or "flex") selecting which attention
    implementation the encoder uses.
    """

    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size: int = None,
        mask_token_id: int = None,
        pad_token_id: int = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: bool = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # LM head weights are kept independent of the input embedding table.
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # Currently identical to the parent implementation; kept as an explicit
        # override point for serialization customization.
        output = super().to_dict()
        return output
119
-
120
-
121
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension: split it in two halves and map (a, b) -> (-b, a).

    This is the standard helper rotation used by rotary position embeddings.
    """
    first, second = x.chunk(2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
124
-
125
-
126
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to ``x`` using precomputed cos/sin tables.

    The tables are trimmed along the sequence axis (dim -2) to match ``x``.
    """
    seq_len = x.shape[-2]
    cos_t = cos[:, :, :seq_len, :]
    sin_t = sin[:, :, :seq_len, :]
    return x * cos_t + rotate_half(x) * sin_t
131
-
132
-
133
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    """Symmetrize ``x`` over its final two dimensions (used for contact prediction)."""
    transposed = x.transpose(-1, -2)
    return transposed + x
136
-
137
-
138
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    """Apply average product correction (APC) over the last two dimensions.

    Subtracts the outer product of row and column sums, normalized by the
    total sum, from ``x`` (used for contact prediction).
    """
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    expected = row_sums * col_sums
    expected.div_(total)  # in-place division keeps peak memory down
    return x - expected
148
-
149
-
150
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features  # expected to equal layers * heads of the backbone
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contact probabilities from stacked attention maps.

        Args:
            input_ids: (batch, seq_len) token ids, used to locate EOS tokens.
            attentions: (batch, layers, heads, seq_len, seq_len) attention maps.

        Returns:
            (batch, seq_len - 2, seq_len - 2) contact probabilities after
            stripping the CLS (first) and final (EOS) positions.
        """
        # remove eos token attentions
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        # Outer product: zero any attention row/column touching an EOS position.
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Flatten (layers, heads) into a single channel dimension for regression.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        # Symmetrize + average product correction, then a per-pair logistic regression.
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
183
-
184
-
185
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq  # no-op; kept as-is from the original implementation
        self.register_buffer("inv_freq", inv_freq)

        # Lazily-built cos/sin tables, cached across forward calls.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cos/sin tables of shape (1, 1, seq_len, dim), rebuilding on cache miss."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # NOTE(review): the cache key is (seq_len, device) only — a change of
        # x.dtype alone reuses tables cast to the previous dtype; confirm callers
        # never switch dtype mid-run without also changing seq_len/device.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate frequencies so the table spans the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate query and key tensors; tables are sized from k's sequence dim."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
226
-
227
-
228
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ):
        """Embed tokens with optional ESM-style mask-token dropout rescaling.

        NOTE(review): when token_dropout is enabled this path reads input_ids
        directly, so calling with only inputs_embeds and input_ids=None would
        raise; confirm callers always pass input_ids in that configuration.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if self.token_dropout:
            # Zero out <mask> token embeddings, then rescale the remainder so the
            # expected magnitude matches training (15% masked * 80% replaced).
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out embeddings at padding positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor

        NOTE(review): this method reads self.padding_idx, which __init__ never
        assigns — it would raise AttributeError if called; verify before use.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
294
-
295
-
296
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with selectable backend (SDPA or flex attention)."""

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.attn_backend = config.attn_backend
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        # The query is pre-scaled by 1/sqrt(head_dim) here, so every attention
        # call below uses scale=1.0 (or raw matmul) to avoid double scaling.
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation - apply scaling here
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            # Additive float mask for SDPA: 0 where attendable, -inf where masked.
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # Flex attention is only attempted for fp16/bf16 inputs, and only
            # when there is either no mask or a precomputed block mask.
            flex_dtype_supported = query_layer.dtype in (torch.float16, torch.bfloat16)
            use_flex = (
                self.attn_backend == "flex"
                and flex_attention is not None
                and flex_dtype_supported
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    # NOTE: no dropout is applied on the flex path, unlike SDPA.
                    context_layer = flex_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        block_mask=flex_block_mask,
                        scale=1.0,
                    )
                except Exception:
                    # Fall back to SDPA if flex attention fails at runtime.
                    context_layer = F.scaled_dot_product_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        attn_mask=sdpa_mask,
                        dropout_p=self.dropout_prob,
                        scale=1.0,
                    )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=self.dropout_prob,
                    scale=1.0
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
402
-
403
-
404
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention block: LN -> self-attention -> residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run pre-norm self-attention with a residual connection.

        Args:
            hidden_states: Input tensor.
            attention_mask: Optional attention mask.
            flex_block_mask: Optional flex-attention block mask.
            output_attentions: Whether to also return attention weights.

        Returns:
            The attended tensor, plus attention weights when requested.
        """
        normed = self.LayerNorm(hidden_states)
        attn_result = self.self(
            normed,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if not output_attentions:
            # Residual is the un-normalized input (pre-norm convention).
            return self.output(attn_result, hidden_states)
        context, weights = attn_result
        return self.output(context, hidden_states), weights
442
-
443
-
444
class EsmLayer(nn.Module):
    """One transformer layer: attention block followed by a pre-norm feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention then the feed-forward block.

        Args:
            hidden_states: Input tensor.
            attention_mask: Optional attention mask.
            flex_block_mask: Optional flex-attention block mask.
            output_attentions: Whether to also return attention weights.

        Returns:
            Layer output, plus attention weights when requested.
        """
        if output_attentions:
            attn_out, attn_weights = self.attention(
                hidden_states,
                attention_mask,
                flex_block_mask,
                output_attentions,
            )
            return self.feed_forward_chunk(attn_out), attn_weights

        attn_out = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        return self.feed_forward_chunk(attn_out)

    def feed_forward_chunk(self, attention_output):
        """Pre-norm MLP with a residual connection back to the attention output."""
        normed = self.LayerNorm(attention_output)
        return self.output(self.intermediate(normed), attention_output)
494
-
495
-
496
class EsmEncoder(nn.Module):
    """Stack of EsmLayer blocks with a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                # Record the state *entering* each layer; the final (post-LN)
                # state is appended after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory by re-running the layer in backward.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # NOTE: truthiness check on an nn.Module — always True since the
        # LayerNorm is unconditionally created in __init__.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
563
-
564
-
565
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): evaluated at class-definition time, i.e. on module import —
    # this may hit the network/disk to fetch the tokenizer; consider lazy loading.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Padding embeddings stay zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        """Return the word-embedding table, whether this is the bare encoder
        (self.embeddings) or a wrapper holding it under self.esm."""
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings
596
-
597
-
598
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    """Bare FastESM encoder: embeddings + transformer stack + contact head."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        # NOTE(review): add_pooling_layer is accepted for signature parity but
        # never used here; pooling lives in FastEsmModel.
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Lightweight embedding path (no hidden states / attentions collected)."""
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        flex_block_mask = None
        if attention_mask is not None:
            # Expand (batch, seq) -> (batch, 1, seq, seq) boolean mask for SDPA.
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
            if self.config.attn_backend == "flex" and create_block_mask is not None:
                flex_block_mask = _create_pad_block_mask(attention_mask.bool())
        else:
            extended_attention_mask = None
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run a full forward with attentions and feed them to the contact head."""
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        # Stack per-layer attentions: (batch, layers, heads, seq, seq).
        attns = torch.stack(attns, dim=1)
        # Zero attention rows and columns at padding positions.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        flex_block_mask = None
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
            # Flex masks are skipped when attentions are requested, since the
            # manual attention path needs the dense boolean mask instead.
            if (
                self.config.attn_backend == "flex"
                and not output_attentions
                and create_block_mask is not None
            ):
                flex_block_mask = _create_pad_block_mask(attention_mask.bool())
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
716
-
717
-
718
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM backbone wrapper adding an optional pooling head on top of the encoder."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Fix: the embedding table lives on the inner encoder (self.esm); this
        # wrapper has no self.embeddings, so the old body raised AttributeError.
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Fix: route through the inner encoder for the same reason as above.
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate the lightweight embedding path to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
793
-
794
-
795
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate the lightweight embedding path to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Encode, project to vocabulary logits, and optionally compute MLM loss.

        Args:
            labels: (batch, seq_len) token ids for cross-entropy; flattened
                against logits of shape (batch * seq_len, vocab_size).
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return EsmMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
850
-
851
-
852
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """Sequence-level classification/regression head on top of the FastESM encoder.

    Loss selection follows the Hugging Face convention: ``config.problem_type``
    chooses MSE (regression), cross-entropy (single-label) or BCE-with-logits
    (multi-label); when unset it is inferred from ``num_labels`` and the label
    dtype on the first labeled batch.
    """

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        # One criterion per problem type; the applicable one is picked in forward().
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: delegate to the base encoder's embedding path.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Contact prediction is implemented on the base encoder.
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify (or regress on) the whole input sequence.

        Returns a ``SequenceClassifierOutput``; ``loss`` is populated only
        when ``labels`` is supplied.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer problem type once; cached on config for later batches.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
920
-
921
-
922
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """Per-residue (token-level) classification head on top of the FastESM encoder."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: reuse the base encoder's embedding path.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Contact prediction is implemented on the base encoder.
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Classify every token of the input sequence.

        Returns a ``TokenClassifierOutput``; a cross-entropy loss is included
        when ``labels`` is supplied.
        """
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        token_states = self.dropout(encoder_out.last_hidden_state)
        token_logits = self.classifier(token_states)

        token_loss = None
        if labels is not None:
            flat_logits = token_logits.view(-1, self.num_labels)
            flat_labels = labels.to(token_logits.device).view(-1)
            token_loss = self.loss_fct(flat_logits, flat_labels)

        return TokenClassifierOutput(
            loss=token_loss,
            logits=token_logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
973
-
974
-
975
if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsmModel and the HF EsmModel.
    In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
    In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        # Prefix with methionine so sequences resemble real protein N-termini.
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        config = FastEsmConfig.from_pretrained(model_path)
        # BUGFIX: from_pretrained is a classmethod; the previous
        # `FastEsmForMaskedLM(config).from_pretrained(model_path)` built a
        # randomly initialized model only to discard it. Call it on the class.
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path, config=config).to(device)
        print('fast model')
        print(fast_model)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
        print('transformers model')
        print(model)

        counts = [0] * len(tolerances)
        for _ in range(seq_count):
            example_seq = generate_random_sequence(length)
            fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            for i, atol in enumerate(tolerances):
                if torch.allclose(fast_output, model_output, atol=atol):
                    counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        # Free GPU memory before loading the next (possibly larger) checkpoint.
        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()
 
 
 
 
 
 
 
1
+ import entrypoint_setup
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from typing import Optional, Tuple, Union, Dict, Any
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
+ from transformers.modeling_outputs import (
10
+ ModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ SequenceClassifierOutput,
14
+ TokenClassifierOutput
15
+ )
16
+ from transformers.models.esm.modeling_esm import (
17
+ EsmIntermediate,
18
+ EsmOutput,
19
+ EsmPooler,
20
+ EsmLMHead,
21
+ EsmSelfOutput,
22
+ EsmClassificationHead,
23
+ )
24
+ try:
25
+ from torch.nn.attention.flex_attention import create_block_mask
26
+ from torch.nn.attention.flex_attention import flex_attention
27
+ except ImportError:
28
+ create_block_mask = None
29
+ flex_attention = None
30
+
31
# Resolve EmbeddingMixin across the three ways this file can be loaded.
# BUGFIX: the bare `except:` clauses also swallowed SystemExit and
# KeyboardInterrupt; only import failures should trigger the fallbacks.
try:
    # when used from AutoModel, these are in the same directory
    from .embedding_mixin import EmbeddingMixin
except ImportError:
    try:
        # when imported as a submodule, embedding_mixin is in the FastPLMs directory
        from ..embedding_mixin import EmbeddingMixin
    except ImportError:
        # when running from our repo, these are in the base directory
        from embedding_mixin import EmbeddingMixin
41
+
42
+
43
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention block mask that hides padding tokens.

    A (query, key) pair participates in attention only when both positions
    are marked valid (non-zero) in the 2-D attention mask.
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    valid = attention_mask_2d.bool()
    n_batch, n_tokens = valid.shape

    def _pair_visible(batch_idx, head_idx, q_idx, kv_idx):
        # head_idx is ignored: the same padding mask applies to every head.
        return valid[batch_idx, q_idx] & valid[batch_idx, kv_idx]

    return create_block_mask(
        _pair_visible,
        n_batch,
        1,
        n_tokens,
        n_tokens,
        device=attention_mask_2d.device,
    )
59
+
60
+
61
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output container for FastEsmForMaskedLM.

    Like transformers' MaskedLMOutput, but additionally carries the final
    hidden states so embeddings can be reused without a second forward pass.
    """

    loss: Optional[torch.Tensor] = None  # MLM cross-entropy loss (only when labels are given)
    logits: Optional[torch.Tensor] = None  # per-token scores over the vocabulary
    last_hidden_state: Optional[torch.Tensor] = None  # final encoder hidden states
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None  # per-layer hidden states, when requested
    attentions: Optional[Tuple[torch.Tensor, ...]] = None  # per-layer attention maps, when requested
68
+
69
+
70
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastESM models (ESM-2 compatible).

    Adds an ``attn_backend`` switch ("sdpa" or "flex") that selects the
    attention implementation used by the encoder.
    """

    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size: int = None,
        mask_token_id: int = None,
        pad_token_id: int = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: bool = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        # pad/mask token ids are handled by PretrainedConfig itself.
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # Input embeddings and the LM head are always kept untied here.
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output
123
+
124
+
125
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    This is the standard helper for rotary position embeddings:
    (a, b) -> (-b, a) along the final axis.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
128
+
129
+
130
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate *x* by the precomputed cos/sin tables (rotary position embedding).

    The tables are truncated to x's sequence length (dim -2) before use.
    """
    seq_len = x.shape[-2]
    cos_t = cos[:, :, :seq_len, :]
    sin_t = sin[:, :, :seq_len, :]
    return (x * cos_t) + (rotate_half(x) * sin_t)
135
+
136
+
137
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    "Make layer symmetric in final two dimensions, used for contact prediction."
    transposed = x.transpose(-1, -2)
    return x + transposed
140
+
141
+
142
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    "Perform average product correct, used for contact prediction."
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    # Expected value under independence: outer product of marginals / total.
    correction = row_sums * col_sums
    correction.div_(total)  # in-place to reduce memory
    return x - correction
152
+
153
+
154
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features  # number of stacked attention channels (layers * heads)
        self.eos_idx = eos_idx  # token id treated as EOS when zeroing attentions
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contact probabilities from stacked attentions.

        ``attentions`` has shape (batch, layers, heads, seq, seq) — see the
        unpack below. CLS and EOS positions are stripped before regression, so
        the returned map covers only the inner residues.
        """
        # remove eos token attentions: zero out any row/column touching EOS...
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        # ...then drop the final (EOS) position entirely.
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions (first position)
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Fold layers and heads into a single channel dimension for regression.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        # Channels last so the linear layer maps channels -> 1 contact logit.
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
187
+
188
+
189
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        """Create rotary tables for heads of size ``dim`` (must be even)."""
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable).
        # FIX: removed a redundant `inv_freq = inv_freq` self-assignment.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Cached cos/sin tables; rebuilt when seq length or device changes.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape (1, 1, seq_len, dim) for *x*."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # Note: the first condition short-circuits on the initial call, before
        # self._cos_cached exists as a tensor.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate frequencies so the table spans the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors (seq dim = -2)."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
230
+
231
+
232
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # BUGFIX: padding_idx was never stored, so
        # create_position_ids_from_inputs_embeds raised AttributeError.
        self.padding_idx = config.pad_token_id
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ):
        """Embed token ids (or pass through ``inputs_embeds``).

        When ``token_dropout`` is enabled, masked-token embeddings are zeroed
        and the remainder rescaled by (1 - train mask ratio) / (1 - observed
        mask ratio), mirroring ESM-2's training-time token dropout.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if self.token_dropout:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8  # hard-coded ESM-2 training mask ratio
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out padding positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # Positions start just after padding_idx, as in fairseq/ESM.
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
298
+
299
+
300
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with three execution paths.

    * manual matmul+softmax — used whenever ``output_attentions`` is set,
      since the fused kernels cannot return attention probabilities;
    * ``flex_attention`` — when ``config.attn_backend == "flex"``;
    * ``F.scaled_dot_product_attention`` — the default ("sdpa") backend.

    The 1/sqrt(head_dim) scaling is folded into the query once
    (``self.query(...) * self.scale``), so all paths run with scale 1.0.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Attention scaling factor, applied to queries in forward().
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.attn_backend = config.attn_backend
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask (truthy entries = attend;
                its complement is filled with -inf / masked out)
            flex_block_mask: Precomputed block mask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        # Scaling is applied to the query once; downstream kernels use scale=1.0.
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation - apply scaling here
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            if self.attn_backend == "flex":
                # Flex path drops attention dropout; see assertions for preconditions.
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                # SDPA path: convert the boolean mask to an additive -inf mask.
                sdpa_mask = None
                if attention_mask is not None:
                    sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                    sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=self.dropout_prob if self.training else 0.0,
                    scale=1.0
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
397
+
398
+
399
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention block: LN -> self-attention -> output projection.

    The output projection (``EsmSelfOutput``) combines the attention result
    with the un-normalized input ``hidden_states``.
    """

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run pre-LN self-attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional block mask for the flex backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor, plus attention weights when requested.
        """
        normed = self.LayerNorm(hidden_states)
        attn_result = self.self(normed, attention_mask, flex_block_mask, output_attentions)
        if not output_attentions:
            return self.output(attn_result, hidden_states)
        context, weights = attn_result
        return self.output(context, hidden_states), weights
437
+
438
+
439
class EsmLayer(nn.Module):
    """A single transformer block: pre-LN attention followed by a pre-LN MLP."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # sequence axis (kept for HF chunking compatibility)
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for transformer layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional block mask for the flex attention backend
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        # self.attention returns a (tensor, weights) pair only when requested.
        if output_attentions:
            attention_output, attention_weights = attention_outputs
        else:
            attention_output = attention_outputs
            attention_weights = None

        layer_output = self.feed_forward_chunk(attention_output)

        if output_attentions:
            return layer_output, attention_weights
        return layer_output

    def feed_forward_chunk(self, attention_output):
        """MLP sub-block: LayerNorm, then intermediate/output projections.

        Note: the residual combination with ``attention_output`` happens
        inside ``EsmOutput`` (transformers' implementation).
        """
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
489
+
490
+
491
class EsmEncoder(nn.Module):
    """Stack of ``EsmLayer`` blocks with a final embedding layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False  # toggled by HF's enable-gradient-checkpointing API

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional block mask for the flex attention backend
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                # Record the *input* to each layer; the final state is added after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # FIX: test for None explicitly instead of relying on nn.Module
        # truthiness (a plain Module is always truthy, so the old `if` could
        # never skip the norm; the explicit check states the intended
        # semantics and stays correct if the attribute becomes Optional).
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
558
+
559
+
560
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): this fetches/loads a tokenizer at class-definition (import)
    # time, which can hit the network on import — consider lazy-loading.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        # Standard truncated-normal-style init scaled by config.initializer_range.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Padding embedding stays zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        """Return the word-embedding module.

        Works both on the bare encoder (``self.embeddings``) and on task heads
        that wrap it under ``self.esm``.
        """
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings
591
+
592
+
593
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    """Bare ESM-2 encoder stack plus a contact-prediction head.

    ``add_pooling_layer`` is accepted for interface compatibility with the
    other FastEsm wrappers but is ignored here: the bare encoder never owns a
    pooler.
    """

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _build_attention_inputs(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        seq_length: int,
        allow_flex: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[Any]]:
        """Convert a 2D padding mask into backend-specific attention inputs.

        Returns ``(extended_attention_mask, flex_block_mask)`` where at most one
        element is non-None:
        - flex backend (and ``allow_flex``): a flex-attention block mask;
        - otherwise: a boolean mask expanded to ``(batch, 1, seq, seq)``.

        Shared by ``_embed`` and ``forward`` so the mask logic exists once.
        """
        if attention_mask is None:
            return None, None
        token_attention_mask = attention_mask.bool()
        if self.config.attn_backend == "flex" and allow_flex:
            assert create_block_mask is not None, (
                "Flex attention backend requested but torch.create_block_mask is unavailable."
            )
            return None, _create_pad_block_mask(token_attention_mask)
        extended_attention_mask = token_attention_mask[:, None, None, :].expand(
            batch_size, 1, seq_length, seq_length
        )
        return extended_attention_mask, None

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Minimal embedding path used by ``EmbeddingMixin``: final hidden states only."""
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        extended_attention_mask, flex_block_mask = self._build_attention_inputs(
            attention_mask, batch_size, seq_length
        )
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contacts from stacked, mask-zeroed attention maps.

        Requests ``output_attentions=True`` (which forces the non-flex mask path
        in ``forward``), zeroes padded rows/columns, and feeds the stacked maps
        to the contact head.
        """
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)
        # Zero out attention to/from padded positions before contact scoring.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs.
            attention_mask: Optional 2D padding mask.
            position_ids: Optional position IDs.
            inputs_embeds: Optional input embeddings (mutually exclusive with
                ``input_ids``).
            output_attentions: Whether to return attention weights; when True
                the flex path is disabled because it does not materialize
                attention matrices.
            output_hidden_states: Whether to return all hidden states.
            return_dict: Accepted for HF compatibility; a model-output object
                is always returned.

        Returns:
            Model outputs including hidden states and optionally attention weights.

        Raises:
            ValueError: If both or neither of ``input_ids``/``inputs_embeds``
                are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        # Flex block masks are only usable when attention maps are not requested.
        extended_attention_mask, flex_block_mask = self._build_attention_inputs(
            attention_mask, batch_size, seq_length, allow_flex=not output_attentions
        )

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
722
+
723
+
724
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastEsm base model: wraps ``FAST_ESM_ENCODER`` and optionally adds a pooler."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # BUG FIX: this class never defines ``self.embeddings`` — the embedding
        # table lives on the wrapped encoder — so the previous implementation
        # raised AttributeError. Delegate to ``self.esm``.
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # BUG FIX: delegate to the wrapped encoder (``self.embeddings`` does not
        # exist on this wrapper).
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: delegate to the wrapped encoder.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs.
            attention_mask: Optional attention mask.
            position_ids: Optional position IDs.
            inputs_embeds: Optional input embeddings (mutually exclusive with
                ``input_ids``).
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            Model outputs including hidden states, pooled output (when a pooler
            is configured), and optionally attention weights.

        Raises:
            ValueError: If both or neither of ``input_ids``/``inputs_embeds``
                are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
799
+
800
+
801
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastEsm encoder topped with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: the wrapped encoder does the real work.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Contact prediction is delegated wholesale to the wrapped encoder.
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Encode the input and score every position over the vocabulary.

        When ``labels`` is provided, also returns a cross-entropy loss over the
        flattened logits/labels.
        """
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        final_states = encoder_out.last_hidden_state
        vocab_logits = self.lm_head(final_states)

        loss = None
        if labels is not None:
            flat_logits = vocab_logits.view(-1, self.config.vocab_size)
            flat_labels = labels.to(vocab_logits.device).view(-1)
            loss = self.loss_fct(flat_logits, flat_labels)

        return EsmMaskedLMOutput(
            loss=loss,
            logits=vocab_logits,
            last_hidden_state=final_states,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
856
+
857
+
858
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastEsm encoder with a sequence-level classification/regression head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: the wrapped encoder does the real work.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Encode the input and classify (or regress on) the whole sequence.

        If ``labels`` is given and ``config.problem_type`` is unset, the task
        is inferred once from the label count and dtype, cached on the config,
        and the matching loss (MSE / CE / BCE-with-logits) is computed.
        """
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(encoder_out.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            problem_type = self.config.problem_type
            if problem_type is None:
                # Infer the task once (same heuristic HF models use) and cache it.
                if self.num_labels == 1:
                    problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    problem_type = "single_label_classification"
                else:
                    problem_type = "multi_label_classification"
                self.config.problem_type = problem_type

            if problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
926
+
927
+
928
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastEsm encoder with a per-token (residue-level) classification head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: the wrapped encoder does the real work.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Encode the input and emit a label distribution for every token.

        When ``labels`` is given, returns a cross-entropy loss over the
        flattened per-token logits/labels.
        """
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        token_states = self.dropout(encoder_out.last_hidden_state)
        token_logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            flat_logits = token_logits.view(-1, self.num_labels)
            flat_labels = labels.to(token_logits.device).view(-1)
            loss = self.loss_fct(flat_logits, flat_labels)

        return TokenClassifierOutput(
            loss=loss,
            logits=token_logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
979
+
980
+
981
if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsmModel and the HF EsmModel.
    In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
    In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        """Random protein-like sequence starting with methionine."""
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        # FIX: call from_pretrained on the class; the previous code built a
        # randomly initialized FastEsmForMaskedLM(config) only to discard it.
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path).to(device)
        print('fast model')
        print(fast_model)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
        print('transformers model')
        print(model)

        counts = [0] * len(tolerances)
        # Inference-only comparison: no_grad avoids autograd bookkeeping.
        with torch.no_grad():
            for _ in range(seq_count):
                example_seq = generate_random_sequence(length)
                # Tokenize once — both models consume identical inputs.
                tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
                fast_output = fast_model(tokens, output_hidden_states=True).hidden_states[-1].cpu()
                model_output = model(tokens, output_hidden_states=True).hidden_states[-1].cpu()

                for i, atol in enumerate(tolerances):
                    if torch.allclose(fast_output, model_output, atol=atol):
                        counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()