lhallee committed on
Commit
e2f35b7
·
verified ·
1 Parent(s): 737f3bf

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +1032 -1021
modeling_fastesm.py CHANGED
@@ -1,1021 +1,1032 @@
1
- import entrypoint_setup
2
- import torch
3
- import torch.nn as nn
4
- import warnings
5
- from torch.nn import functional as F
6
- from torch.nn.attention.flex_attention import create_block_mask
7
- from torch.nn.attention.flex_attention import flex_attention
8
- from typing import Optional, Tuple, Union, Dict, Any
9
- from einops import rearrange
10
- from dataclasses import dataclass
11
- from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
12
- from transformers.modeling_outputs import (
13
- ModelOutput,
14
- BaseModelOutputWithPastAndCrossAttentions,
15
- BaseModelOutputWithPoolingAndCrossAttentions,
16
- SequenceClassifierOutput,
17
- TokenClassifierOutput
18
- )
19
- from transformers.models.esm.modeling_esm import (
20
- EsmIntermediate,
21
- EsmOutput,
22
- EsmPooler,
23
- EsmLMHead,
24
- EsmSelfOutput,
25
- EsmClassificationHead,
26
- )
27
-
28
try:
    # when used from AutoModel, these are in the same directory
    from .embedding_mixin import EmbeddingMixin, Pooler
except ImportError:
    # when running from our repo, these are in the base directory
    # (narrowed from a bare `except:` so that real errors inside the module —
    # syntax errors, failing transitive imports — are not silently masked)
    from embedding_mixin import EmbeddingMixin, Pooler
34
-
35
-
36
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention BlockMask that hides padded positions.

    Args:
        attention_mask_2d: (batch, seq) mask; nonzero entries mark valid tokens.

    Returns:
        A BlockMask permitting attention only between pairs of valid tokens.
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # A (query, key) pair participates only if both tokens are real (non-pad).
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    # Head dimension of 1: the same padding mask broadcasts across all heads.
    return create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )
52
-
53
-
54
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output type for FastEsmForMaskedLM.

    Extends the usual masked-LM output with ``last_hidden_state`` so callers
    can reuse the final encoder representation without a second forward pass.
    """

    loss: Optional[torch.Tensor] = None  # masked-LM loss, present only when labels are given
    logits: Optional[torch.Tensor] = None  # prediction scores over the vocabulary
    last_hidden_state: Optional[torch.Tensor] = None  # final encoder hidden states
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None  # per-layer hidden states, if requested
    attentions: Optional[Tuple[torch.Tensor, ...]] = None  # per-layer attention maps, if requested
61
-
62
-
63
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastESM models.

    Mirrors the standard ESM-2 configuration, adding ``attn_backend`` to
    choose the attention implementation ("sdpa" or "flex"). Word embeddings
    are never tied to the LM head (``tie_word_embeddings`` is forced False).
    """

    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size: Optional[int] = None,
        mask_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: Optional[bool] = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # LM head weights are independent of the input embedding table.
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output
116
-
117
-
118
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by swapping its halves and negating the second.

    Given x = [a, b] along the final axis, returns [-b, a]; this is the
    standard helper used by rotary position embeddings.
    """
    upper, lower = torch.chunk(x, 2, dim=-1)
    return torch.cat((lower.neg(), upper), dim=-1)
121
-
122
-
123
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to ``x``.

    The cached cos/sin tables are truncated to x's sequence length
    (dimension -2) before the rotation is applied.
    """
    seq_len = x.shape[-2]
    cos_trunc = cos[:, :, :seq_len, :]
    sin_trunc = sin[:, :, :seq_len, :]
    return x * cos_trunc + rotate_half(x) * sin_trunc
128
-
129
-
130
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    """Symmetrize the final two dimensions (x + x^T); used for contact prediction."""
    transposed = x.transpose(-1, -2)
    return x + transposed
133
-
134
-
135
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    """Apply average product correction (APC); used for contact prediction.

    Subtracts (row_sum * col_sum) / total_sum from each entry over the last
    two dimensions. The division is done in place on the freshly allocated
    product tensor to keep peak memory down.
    """
    row_sum = x.sum(-1, keepdims=True)
    col_sum = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    correction = row_sum * col_sum
    correction.div_(total)  # in-place to reduce memory
    return x - correction
145
-
146
-
147
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        """
        Args:
            in_features: Number of input channels; one per (layer, head)
                attention map, i.e. num_layers * num_attention_heads.
            bias: Whether the regression layer carries a bias term.
            eos_idx: Token id treated as end-of-sequence when masking.
        """
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        # Per-pair logistic regression over the stacked attention channels.
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contact probabilities from attention maps.

        Args:
            input_ids: (batch, seq) token ids, used to locate EOS tokens.
            attentions: (batch, layers, heads, seq, seq) stacked attention maps.

        Returns:
            (batch, seq - 2, seq - 2) contact probabilities; the CLS position
            and the final position are stripped from both axes.
        """
        # remove eos token attentions
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Collapse (layers, heads) into one channel axis for the regression.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
180
-
181
-
182
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim: Per-head dimension the rotation is applied over (must be even).
        """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable).
        # (Removed a dead no-op `inv_freq = inv_freq` self-assignment.)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily built cos/sin tables, cached across forward calls.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cos/sin tables matching x's sequence length, device and dtype."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # On the first call _seq_len_cached is None, so the length check
        # short-circuits before _cos_cached.device is touched.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Shaped (1, 1, seq, dim) so they broadcast over batch and heads.
            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate query and key tensors; tables are sized from k's seq dim."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
223
-
224
-
225
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Note: this module only produces token (word) embeddings — no positional
    embeddings are added here.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ) -> torch.Tensor:
        """Embed tokens and optionally apply ESM-style mask-token rescaling.

        NOTE(review): position_ids and past_key_values_length are accepted
        for API compatibility but never read below — confirm this is intended.

        NOTE(review): when token_dropout is True, this path indexes input_ids
        directly, so calling with only inputs_embeds (input_ids=None) would
        raise; likewise the ones_like fallback below needs input_ids. Verify
        callers always pass input_ids in that mode.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if self.token_dropout:
            # Zero out <mask> token embeddings, then rescale the rest so the
            # expected magnitude matches training (15% masking, 80% -> <mask>).
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out padded positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        NOTE(review): references self.padding_idx, which is never assigned in
        __init__ — calling this would raise AttributeError. Possibly
        self.word_embeddings.padding_idx was intended; confirm before use.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
291
-
292
-
293
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with selectable SDPA / flex-attention backends.

    The query is pre-scaled by head_dim ** -0.5 in forward(), so every backend
    call below passes scale=1.0 to avoid applying the scaling twice.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Applied to the query up front; backends then run with scale=1.0.
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.attn_backend = config.attn_backend
        # Ensures the flex->SDPA fallback warning fires at most once.
        self._warned_flex_fallback = False
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor of shape (batch, seq, hidden)
            attention_mask: Optional boolean mask, broadcastable to
                (batch, heads, seq, seq); True marks attendable positions
            flex_block_mask: Optional precomputed flex-attention block mask
            output_attentions: Whether to return attention weights; forces the
                manual softmax path (no SDPA/flex)

        Returns:
            Output tensor and optionally attention weights
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation - apply scaling here
            # (the query is already scaled, so no extra factor is needed).
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            # Convert the boolean mask into an additive float mask for SDPA
            # (0 where attendable, -inf where masked).
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # Flex is only usable when no mask is needed or a block mask was
            # prebuilt; flex cannot consume the dense boolean mask directly.
            use_flex = (
                self.attn_backend == "flex"
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    context_layer = flex_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        block_mask=flex_block_mask,
                        scale=1.0,
                    )
                except Exception as exc:
                    if not self._warned_flex_fallback:
                        warnings.warn(
                            f"Flex attention failed in FastESM attention; falling back to SDPA. Error: {exc}",
                            RuntimeWarning,
                        )
                        self._warned_flex_fallback = True
                    context_layer = F.scaled_dot_product_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        attn_mask=sdpa_mask,
                        dropout_p=self.dropout_prob,
                        scale=1.0,
                    )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=self.dropout_prob,
                    scale=1.0
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
403
-
404
-
405
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention block: LN -> self-attention -> residual output."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer-normalized self-attention with a residual projection.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor, plus attention weights when requested
        """
        normed = self.LayerNorm(hidden_states)
        attn_result = self.self(
            normed,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if not output_attentions:
            # Residual add + projection happens inside self.output.
            return self.output(attn_result, hidden_states)
        context, weights = attn_result
        projected = self.output(context, hidden_states)
        return projected, weights
443
-
444
-
445
class EsmLayer(nn.Module):
    """Single transformer block: attention followed by a pre-LN feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention then the feed-forward sub-block.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor, plus attention weights when requested
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if output_attentions:
            attn_output, attn_weights = attn_result
            return self.feed_forward_chunk(attn_output), attn_weights
        return self.feed_forward_chunk(attn_result)

    def feed_forward_chunk(self, attention_output):
        # Pre-LN MLP: normalize, expand, then residual projection via self.output.
        normed = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed)
        return self.output(expanded, attention_output)
495
-
496
-
497
class EsmEncoder(nn.Module):
    """Stack of EsmLayer blocks followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                # Record the layer *input* (the final state is appended after the loop).
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # Fix: explicit None check. The old `if self.emb_layer_norm_after:`
        # relied on nn.Module truthiness; an explicit identity test states the
        # intent and stays correct if the attribute is ever made optional.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
564
-
565
-
566
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): evaluated at class-creation time — importing this module
    # triggers tokenizer download/loading as a side effect. Confirm this is
    # intended rather than lazy-loading on first use.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Normal init with std = config.initializer_range; zero bias.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        # Works for bare encoders (self.embeddings) and for wrapper models
        # that hold the encoder under self.esm.
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings
597
-
598
-
599
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    """Bare FastESM encoder: embeddings + transformer stack + contact head.

    NOTE(review): add_pooling_layer is accepted but never used here — no
    pooler is created in this class; confirm the parameter exists only for
    signature compatibility.
    """

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        # One input channel per (layer, head) attention map.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Minimal forward returning only the last hidden state.

        Presumably the hook consumed by EmbeddingMixin — confirm against
        embedding_mixin.py.
        """
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        if attention_mask is not None:
            # Expand (batch, seq) -> boolean (batch, 1, seq, seq) pairwise mask.
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Predict residue contacts from all layers' attention maps."""
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        # (batch, layers, heads, seq, seq)
        attns = torch.stack(attns, dim=1)
        # Zero attention rows and columns at padded positions.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights

        NOTE(review): return_dict is accepted but ignored; the output is
        always a BaseModelOutputWithPoolingAndCrossAttentions. Also, no flex
        block mask is built here even when config.attn_backend == "flex" —
        with a padded batch, attention therefore falls back to SDPA; confirm
        whether _create_pad_block_mask was meant to be wired in.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        if attention_mask is not None:
            # Expand (batch, seq) -> boolean (batch, 1, seq, seq) pairwise mask.
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
705
-
706
-
707
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM base model: wraps FAST_ESM_ENCODER and adds an optional pooler."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Fix: this wrapper has no `self.embeddings` attribute — the embedding
        # table lives on the inner encoder. The previous implementation
        # (`self.embeddings.word_embeddings`) raised AttributeError.
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Fix: route through the inner encoder (see get_input_embeddings).
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate minimal embedding forward to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states, pooled output (when a
            pooler was created), and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
782
-
783
-
784
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        # Default CrossEntropyLoss: label positions set to -100 are ignored.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate minimal embedding forward to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Run the encoder, score every position over the vocabulary, and
        compute the MLM cross-entropy loss when labels are provided.

        NOTE(review): return_dict is accepted but ignored; the return type is
        always EsmMaskedLMOutput.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            # Flatten (batch, seq, vocab) / (batch, seq) for cross entropy.
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return EsmMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
839
-
840
-
841
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """ESM encoder with a sequence-level classification/regression head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        # One criterion per problem type; the right one is selected at
        # forward time from config.problem_type.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: delegate to the encoder's embedding path.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify or regress on the sequence representation.

        NOTE(review): `return_dict` is accepted but ignored — a
        SequenceClassifierOutput is always returned.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Standard HF convention: infer problem_type once from num_labels
            # and label dtype, then cache the decision on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
909
-
910
-
911
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """ESM encoder with a per-token classification head (dropout + linear)."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # EmbeddingMixin hook: delegate to the encoder's embedding path.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Predict a label per token with cross-entropy loss when labels are given.

        NOTE(review): `return_dict` is accepted but ignored — a
        TokenClassifierOutput is always returned.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
962
-
963
-
964
if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsmModel and the HF EsmModel.
    In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
    In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        # Start with methionine to mimic a real protein N-terminus.
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        # Fix: from_pretrained is a classmethod. The previous code first
        # instantiated a randomly-initialized model from the config and then
        # immediately discarded it via `FastEsmForMaskedLM(config).from_pretrained(...)`.
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path).to(device)
        print('fast model')
        print(fast_model)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
        print('transformers model')
        print(model)

        counts = [0] * len(tolerances)
        for _ in range(seq_count):
            example_seq = generate_random_sequence(length)
            fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            for i, atol in enumerate(tolerances):
                if torch.allclose(fast_output, model_output, atol=atol):
                    counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        # Free GPU memory before loading the next checkpoint.
        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import entrypoint_setup
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from typing import Optional, Tuple, Union, Dict, Any
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
+ from transformers.modeling_outputs import (
10
+ ModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ SequenceClassifierOutput,
14
+ TokenClassifierOutput
15
+ )
16
+ from transformers.models.esm.modeling_esm import (
17
+ EsmIntermediate,
18
+ EsmOutput,
19
+ EsmPooler,
20
+ EsmLMHead,
21
+ EsmSelfOutput,
22
+ EsmClassificationHead,
23
+ )
24
+ try:
25
+ from torch.nn.attention.flex_attention import create_block_mask
26
+ from torch.nn.attention.flex_attention import flex_attention
27
+ except ImportError:
28
+ create_block_mask = None
29
+ flex_attention = None
30
+
31
# Import shared mixins whether the file is loaded as a package module
# (AutoModel / trust_remote_code) or run directly from the repo root.
try:
    # when used from AutoModel, these are in the same directory
    from .embedding_mixin import EmbeddingMixin, Pooler
except ImportError:
    # Fix: narrowed from a bare `except:`, which also swallowed
    # SystemExit/KeyboardInterrupt and genuine errors inside embedding_mixin.
    # when running from our repo, these are in the base directory
    from embedding_mixin import EmbeddingMixin, Pooler
37
+
38
+
39
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention block mask that hides padding positions.

    Args:
        attention_mask_2d: (batch, seq_len) mask where nonzero marks real tokens.

    Returns:
        A block mask usable as ``flex_attention(..., block_mask=...)``.
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    valid = attention_mask_2d.bool()
    n_batch, n_seq = valid.shape

    def pad_mask_mod(b, h, q, kv):
        # A (query, key) pair participates only when both positions are real tokens.
        return valid[b, q] & valid[b, kv]

    return create_block_mask(
        pad_mask_mod, n_batch, 1, n_seq, n_seq, device=attention_mask_2d.device
    )
55
+
56
+
57
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output container for FastEsmForMaskedLM.

    Mirrors transformers' MaskedLMOutput but additionally carries the final
    hidden states so callers can reuse them without a second forward pass.
    """
    # Masked-LM cross-entropy loss; present only when labels were supplied.
    loss: Optional[torch.Tensor] = None
    # Per-token vocabulary scores produced by the LM head.
    logits: Optional[torch.Tensor] = None
    # Final encoder hidden states.
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention maps when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
64
+
65
+
66
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastEsm models.

    Mirrors the ESM-2 configuration. `attn_backend` selects the attention
    kernel: "sdpa" (scaled_dot_product_attention) or "flex" (flex_attention).
    Word embeddings are never tied to the LM head (`tie_word_embeddings=False`).
    """
    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size: int = None,
        mask_token_id: int = None,
        pad_token_id: int = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: bool = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # The LM head decoder is kept separate from the input embeddings.
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output
119
+
120
+
121
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in half and rotate: (a, b) -> (-b, a).

    Used by rotary position embeddings to form the sine component.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat([-second_half, first_half], dim=-1)
124
+
125
+
126
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate *x* by the cached cos/sin tables, truncated to x's sequence length."""
    seq_len = x.shape[-2]
    cos_trunc = cos[:, :, :seq_len, :]
    sin_trunc = sin[:, :, :seq_len, :]
    return x * cos_trunc + rotate_half(x) * sin_trunc
131
+
132
+
133
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    "Make layer symmetric in final two dimensions, used for contact prediction."
    transposed = x.transpose(-1, -2)
    return x + transposed
136
+
137
+
138
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    "Perform average product correct, used for contact prediction."
    row_totals = x.sum(-1, keepdims=True)
    col_totals = x.sum(-2, keepdims=True)
    grand_total = x.sum((-1, -2), keepdims=True)

    correction = row_totals * col_totals
    correction.div_(grand_total)  # in-place to reduce memory
    return x - correction
148
+
149
+
150
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Turn stacked attention maps into per-pair contact probabilities."""
        # Zero out any attention row/column that touches the EOS token.
        non_eos = input_ids.ne(self.eos_idx).to(attentions)
        pair_mask = non_eos.unsqueeze(1) * non_eos.unsqueeze(2)
        attentions = attentions * pair_mask[:, None, None, :, :]
        # Drop the EOS (last) and CLS (first) rows/columns.
        attentions = attentions[..., :-1, :-1]
        attentions = attentions[..., 1:, 1:]
        n_batch, n_layers, n_heads, n_tokens, _ = attentions.size()
        attentions = attentions.view(n_batch, n_layers * n_heads, n_tokens, n_tokens)

        # features: batch x channels x tokens x tokens (symmetric);
        # attentions are always float32 and may need moving/casting to the head.
        attentions = attentions.to(self.regression.weight.device)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
183
+
184
+
185
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable).
        # Fix: removed the dead no-op assignment `inv_freq = inv_freq`.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily built cos/sin tables, cached per sequence length and device.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rebuild the cached cos/sin tables when the seq length or device changes."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # On the very first call _cos_cached is None, but the seq_len
        # comparison short-circuits before the .device access.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query/key tensors of shape (..., seq, head_dim)."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
226
+
227
+
228
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing,
    plus ESM-2's "token dropout" rescaling of mask-token embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # Fix: padding_idx was never assigned, so
        # create_position_ids_from_inputs_embeds raised AttributeError.
        self.padding_idx = config.pad_token_id
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ):
        """Embed tokens, apply optional token-dropout rescaling, and zero padding positions."""
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                # Fix: previously crashed on torch.ones_like(None) when only
                # inputs_embeds were provided without a mask.
                attention_mask = torch.ones(
                    embeddings.shape[:2], dtype=torch.long, device=embeddings.device
                )

        # ESM-2 token dropout: zero mask-token embeddings and rescale by the
        # train/observed mask ratio. Requires input_ids to locate mask tokens,
        # so it is skipped (fix: previously crashed) for pure inputs_embeds.
        if self.token_dropout and input_ids is not None:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8  # fraction of tokens masked during ESM-2 pretraining
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out embeddings at padding positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
294
+
295
+
296
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with optional rotary embeddings.

    The query is pre-scaled by head_dim**-0.5, so every downstream attention
    call (manual matmul, SDPA, flex) runs with scale=1.0.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Applied to the query once so attention kernels can use scale=1.0.
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.attn_backend = config.attn_backend
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional boolean mask (True = attend), broadcastable
                to (batch, heads, seq, seq)
            flex_block_mask: Optional precomputed flex-attention block mask
            output_attentions: Whether to return attention weights (forces the
                manual matmul path)

        Returns:
            Output tensor and optionally attention weights
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # Fix: attention dropout must only be active in training mode. The
        # previous code passed dropout_p unconditionally to SDPA, applying
        # dropout (and nondeterminism) at inference time.
        dropout_p = self.dropout_prob if self.training else 0.0

        if output_attentions:
            # Manual attention computation so per-head weights can be returned
            # (query is already scaled).
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                # F.dropout already respects self.training via the flag below.
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            # Convert the boolean mask into an additive float mask for SDPA.
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            # flex_attention is used only for half-precision inputs and only
            # when padding (if any) has a precomputed block mask.
            flex_dtype_supported = query_layer.dtype in (torch.float16, torch.bfloat16)
            use_flex = (
                self.attn_backend == "flex"
                and flex_attention is not None
                and flex_dtype_supported
                and (attention_mask is None or flex_block_mask is not None)
            )
            if use_flex:
                try:
                    context_layer = flex_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        block_mask=flex_block_mask,
                        scale=1.0,
                    )
                except Exception:
                    # Best-effort: any flex failure falls back to SDPA.
                    context_layer = F.scaled_dot_product_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        attn_mask=sdpa_mask,
                        dropout_p=dropout_p,
                        scale=1.0,
                    )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=dropout_p,
                    scale=1.0,
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
402
+
403
+
404
class EsmAttention(nn.Module):
    """Pre-LayerNorm self-attention block with a residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for attention layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        # Pre-norm: LayerNorm feeds attention, residual uses raw input.
        normed_states = self.LayerNorm(hidden_states)
        self_result = self.self(
            normed_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if not output_attentions:
            return self.output(self_result, hidden_states)
        context, weights = self_result
        return self.output(context, hidden_states), weights
442
+
443
+
444
class EsmLayer(nn.Module):
    """One transformer block: pre-LN attention followed by a pre-LN feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for transformer layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if output_attentions:
            attn_output, attn_weights = attn_result
            return self.feed_forward_chunk(attn_output), attn_weights
        return self.feed_forward_chunk(attn_result)

    def feed_forward_chunk(self, attention_output):
        # Pre-norm feed-forward with residual handled inside self.output.
        normed = self.LayerNorm(attention_output)
        return self.output(self.intermediate(normed), attention_output)
494
+
495
+
496
class EsmEncoder(nn.Module):
    """Stack of EsmLayers with a final LayerNorm and optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask passed to each layer
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # Fix: test explicitly for None rather than relying on nn.Module
        # truthiness ("if self.emb_layer_norm_after:") — a Module is always
        # truthy, so the old check was misleading even though it worked.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
563
+
564
+
565
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): this loads a tokenizer at class-definition (import) time,
    # which needs network or local cache access — consider making it lazy.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        # Normal init for Linear/Embedding, ones/zeros for LayerNorm;
        # the embedding row at padding_idx is zeroed.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        # Works for both the bare encoder (self.embeddings) and task heads
        # that wrap it under self.esm.
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings
596
+
597
+
598
+ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
599
+ def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
600
+ FastEsmPreTrainedModel.__init__(self, config, **kwargs)
601
+ self.config = config
602
+ self.embeddings = EsmEmbeddings(config)
603
+ self.encoder = EsmEncoder(config)
604
+ self.contact_head = EsmContactPredictionHead(
605
+ in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
606
+ )
607
+ # Initialize weights and apply final processing
608
+ self.post_init()
609
+
610
+ def get_input_embeddings(self):
611
+ return self.embeddings.word_embeddings
612
+
613
+ def set_input_embeddings(self, value):
614
+ self.embeddings.word_embeddings = value
615
+
616
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
617
+ token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
618
+ batch_size, seq_length = input_ids.shape
619
+ flex_block_mask = None
620
+ if attention_mask is not None:
621
+ extended_attention_mask = attention_mask[:, None, None, :].expand(
622
+ batch_size, 1, seq_length, seq_length
623
+ ).bool()
624
+ if self.config.attn_backend == "flex" and create_block_mask is not None:
625
+ flex_block_mask = _create_pad_block_mask(attention_mask.bool())
626
+ else:
627
+ extended_attention_mask = None
628
+ encoder_outputs = self.encoder(
629
+ token_embedding_output,
630
+ attention_mask=extended_attention_mask,
631
+ flex_block_mask=flex_block_mask,
632
+ output_hidden_states=False,
633
+ output_attentions=False,
634
+ )
635
+ return encoder_outputs.last_hidden_state
636
+
637
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
638
+ attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
639
+ attns = torch.stack(attns, dim=1)
640
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
641
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
642
+ return self.contact_head(input_ids, attns)
643
+
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
    """Encode a batch of sequences.

    Args:
        input_ids: Input token IDs.
        attention_mask: Optional 2D padding mask (1 = keep, 0 = pad).
        position_ids: Optional position IDs forwarded to the embedding layer.
        inputs_embeds: Optional precomputed input embeddings (mutually
            exclusive with ``input_ids``).
        output_attentions: Whether to return per-layer attention weights.
        output_hidden_states: Whether to return all hidden states.
        return_dict: Ignored; the output is always a ModelOutput dataclass.

    Returns:
        BaseModelOutputWithPoolingAndCrossAttentions with the last hidden
        state and, if requested, hidden states and attentions.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is not None:
        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    batch_size, seq_length = input_shape

    hidden = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
    )

    extended_mask, block_mask = None, None
    if attention_mask is not None:
        # (batch, 1, seq, seq) boolean mask for the attention kernels.
        extended_mask = attention_mask[:, None, None, :].expand(
            batch_size, 1, seq_length, seq_length
        ).bool()
        # Flex attention cannot return attention maps, so only build the
        # block mask when attentions were not requested.
        use_flex = (
            self.config.attn_backend == "flex"
            and not output_attentions
            and create_block_mask is not None
        )
        if use_flex:
            block_mask = _create_pad_block_mask(attention_mask.bool())

    encoder_outputs = self.encoder(
        hidden,
        attention_mask=extended_mask,
        flex_block_mask=block_mask,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
    )

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=encoder_outputs.last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
716
+
717
+
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """Bare FastESM encoder with an optional pooling layer on top.

    Thin wrapper around ``FAST_ESM_ENCODER`` that adds an ``EsmPooler``
    (when ``add_pooling_layer`` is True) and exposes the HF base-model API.
    """

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Bug fix: the embedding table lives on the wrapped encoder (self.esm),
        # not on this wrapper — `self.embeddings` does not exist here and
        # previously raised AttributeError.
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Delegate to the wrapped encoder (see get_input_embeddings).
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding to the wrapped encoder (EmbeddingMixin hook)."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs.
            attention_mask: Optional attention mask.
            position_ids: Optional position IDs.
            inputs_embeds: Optional input embeddings (mutually exclusive with input_ids).
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            Model outputs including last hidden state, pooled output (if the
            pooler exists), and optionally hidden states / attention weights.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
793
+
794
+
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder topped with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        """Return the LM head's output projection."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection."""
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding to the wrapped encoder (EmbeddingMixin hook)."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Run the encoder, project to vocabulary logits, and (optionally) compute CE loss."""
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        hidden = outputs.last_hidden_state
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.to(logits.device).view(-1)
            loss = self.loss_fct(flat_logits, flat_labels)

        return EsmMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
850
+
851
+
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a sequence-level classification/regression head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding to the wrapped encoder (EmbeddingMixin hook)."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify/regress on the full-sequence representation.

        When ``labels`` is given, the loss function follows the HF
        ``problem_type`` convention (inferred on first call if unset).
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer and cache the problem type the first time labels are seen.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
920
+
921
+
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a per-token classification head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding to the wrapped encoder (EmbeddingMixin hook)."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Produce a label distribution for every token; CE loss when labels given."""
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        dropped = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(dropped)

        loss = None
        if labels is not None:
            flat_labels = labels.to(logits.device).view(-1)
            loss = self.loss_fct(logits.view(-1, self.num_labels), flat_labels)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
973
+
974
+
if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsmModel and the HF EsmModel.
    In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
    In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        # Start with 'M' so sequences look like real proteins (initiator methionine).
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        # Fix: from_pretrained is a classmethod — calling it on a freshly
        # constructed instance (the old `FastEsmForMaskedLM(config).from_pretrained(...)`)
        # built and threw away a randomly initialized model first.
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path).to(device)
        print('fast model')
        print(fast_model)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
        print('transformers model')
        print(model)

        counts = [0] * len(tolerances)
        # No gradients are needed for this comparison; inference_mode saves memory/time.
        with torch.inference_mode():
            for _ in range(seq_count):
                example_seq = generate_random_sequence(length)
                # Tokenize once — both models receive identical inputs.
                tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
                fast_output = fast_model(tokens, output_hidden_states=True).hidden_states[-1].cpu()
                model_output = model(tokens, output_hidden_states=True).hidden_states[-1].cpu()

                for i, atol in enumerate(tolerances):
                    if torch.allclose(fast_output, model_output, atol=atol):
                        counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()