lhallee committed on
Commit
2e8dabc
·
verified ·
1 Parent(s): d87df29

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +981 -981
modeling_fastesm.py CHANGED
@@ -1,981 +1,981 @@
1
- import entrypoint_setup
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from typing import Optional, Tuple, Union, Dict, Any
6
- from einops import rearrange
7
- from dataclasses import dataclass
8
- from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
- from transformers.modeling_outputs import (
10
- ModelOutput,
11
- BaseModelOutputWithPastAndCrossAttentions,
12
- BaseModelOutputWithPoolingAndCrossAttentions,
13
- SequenceClassifierOutput,
14
- TokenClassifierOutput
15
- )
16
- from transformers.models.esm.modeling_esm import (
17
- EsmIntermediate,
18
- EsmOutput,
19
- EsmPooler,
20
- EsmLMHead,
21
- EsmSelfOutput,
22
- EsmClassificationHead,
23
- )
24
- try:
25
- from torch.nn.attention.flex_attention import create_block_mask
26
- from torch.nn.attention.flex_attention import flex_attention
27
- except ImportError:
28
- create_block_mask = None
29
- flex_attention = None
30
-
31
- try:
32
- from .embedding_mixin import EmbeddingMixin
33
- except ImportError:
34
- try:
35
- from ..embedding_mixin import EmbeddingMixin
36
- except ImportError:
37
- from embedding_mixin import EmbeddingMixin
38
-
39
-
40
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention block mask that hides padding tokens.

    A (query, key) pair may attend only when both positions are valid
    (non-zero) in ``attention_mask_2d`` of shape (batch, seq_len).
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    valid = attention_mask_2d.bool()
    n_batch, n_tokens = valid.shape

    def _pair_visible(b, h, q, kv):
        # Both endpoints of the attention edge must be real (non-pad) tokens.
        return valid[b, q] & valid[b, kv]

    return create_block_mask(
        _pair_visible,
        n_batch,
        1,  # one mask shared across all heads
        n_tokens,
        n_tokens,
        device=attention_mask_2d.device,
    )
56
-
57
-
58
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output container for FastESM masked-LM forward passes.

    Mirrors the standard HF masked-LM output but additionally carries
    ``last_hidden_state`` so callers can reuse the final encoder
    representations without a second forward pass.
    """
    # Masked-LM cross-entropy loss; set only when labels are provided.
    loss: Optional[torch.Tensor] = None
    # Token-level vocabulary logits, shape (batch, seq_len, vocab_size).
    logits: Optional[torch.Tensor] = None
    # Final encoder hidden states, shape (batch, seq_len, hidden_size).
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states; present when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention maps; present when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
65
-
66
-
67
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastESM encoder models.

    Holds the ESM-style transformer hyper-parameters plus ``attn_backend``,
    which selects the attention implementation ("sdpa" or "flex").
    """
    model_type = "fast_esm"
    def __init__(
        self,
        vocab_size: Optional[int] = None,
        mask_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: Optional[bool] = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        # pad/mask token ids are forwarded so PretrainedConfig stores them
        # under its canonical attribute names.
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # FastESM never ties the LM head to the input embeddings.
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output
120
-
121
-
122
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: maps (x1, x2) -> (-x2, x1).

    Used by the rotary position embedding to form the sine component.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
125
-
126
-
127
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to ``x``.

    The cached ``cos``/``sin`` tables may cover a longer sequence than ``x``,
    so both are truncated to x's sequence length (dim -2) before mixing.
    """
    seq_len = x.shape[-2]
    cos_trim = cos[:, :, :seq_len, :]
    sin_trim = sin[:, :, :seq_len, :]
    return x * cos_trim + rotate_half(x) * sin_trim
132
-
133
-
134
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    "Make layer symmetric in final two dimensions, used for contact prediction."
    transposed = x.transpose(-1, -2)
    return x + transposed
137
-
138
-
139
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    """Apply average product correction over the final two dims (contact prediction).

    Subtracts the outer product of row/column marginals, normalized by the
    total, from ``x``.
    """
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    expected = row_sums * col_sums
    # In-place division keeps peak memory down on large attention stacks.
    expected.div_(total)
    return x - expected
149
-
150
-
151
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        # in_features is expected to be num_layers * num_attention_heads:
        # one regression channel per attention map.
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contact probabilities from stacked attention maps.

        Args:
            input_ids: (batch, seq_len) token ids, used to locate EOS tokens.
            attentions: (batch, layers, heads, seq_len, seq_len) attention maps.

        Returns:
            (batch, seq_len - 2, seq_len - 2) contact probabilities with the
            first (CLS) and last (EOS) positions stripped.
        """
        # remove eos token attentions: zero out any row/column touching an EOS
        # position, then drop the final row/column.
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions (first row/column)
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Fold (layers, heads) into a single channel axis for the regression.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        # Move channels last so nn.Linear mixes across attention maps per pair.
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
184
-
185
-
186
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq  # NOTE(review): no-op assignment retained from upstream ESM code
        self.register_buffer("inv_freq", inv_freq)

        # Lazily-built cos/sin tables, rebuilt when sequence length or
        # device changes (see _update_cos_sin_tables).
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape (1, 1, seq_len, dim) matching x's device/dtype."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # NOTE(review): the cache is not invalidated on a dtype-only change at
        # a fixed seq_len/device — confirm callers never switch dtype mid-run.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate frequencies so the table spans the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate query and key tensors by their absolute positions.

        Tables are sized from ``k`` (dim -2); both q and k are rotated with
        the same tables.
        """
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
227
-
228
-
229
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Fix: forward() previously crashed when called with ``inputs_embeds`` and no
    ``input_ids`` (``torch.ones_like(None)`` and a mask comparison on None).
    The token-dropout rescaling and the default all-ones attention mask now
    apply only when ``input_ids`` is available; behavior is unchanged for all
    previously-working call patterns.
    """

    def __init__(self, config):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = config.position_embedding_type
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ) -> torch.Tensor:
        """Embed tokens, optionally applying ESM-style token-dropout rescaling.

        Args:
            input_ids: (batch, seq_len) token ids; optional if inputs_embeds given.
            attention_mask: (batch, seq_len) 1/0 mask; padded positions are zeroed.
            position_ids: Unused here; accepted for interface compatibility.
            inputs_embeds: Precomputed embeddings, bypassing the lookup table.
            past_key_values_length: Unused here; accepted for interface compatibility.

        Returns:
            (batch, seq_len, hidden_size) embedding tensor.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        # Token-dropout compensation requires locating <mask> tokens, which is
        # only possible when raw token ids are provided.
        if self.token_dropout and input_ids is not None:
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            # 15% of tokens are selected for masking during training, and 80%
            # of those become <mask> — hence the expected train-time ratio.
            mask_ratio_train = 0.15 * 0.8
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out embeddings at padded positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # Positions start after padding_idx, matching fairseq/ESM convention.
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
296
-
297
-
298
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with three execution paths: manual (when
    attention probabilities are requested), flex attention, or SDPA.

    The 1/sqrt(d) scale is pre-folded into the query projection output, so
    the downstream attention calls run with scale=1.0.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Standard 1/sqrt(head_dim) attention scaling factor.
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        # "flex" or "sdpa" — chosen in the non-output_attentions path.
        self.attn_backend = config.attn_backend
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional boolean attention mask (broadcastable to
                (batch, heads, seq, seq)); True marks attendable positions
            flex_block_mask: Optional flex-attention block mask; required by
                the "flex" backend whenever attention_mask is given
            output_attentions: Whether to return attention weights (forces the
                manual computation path)

        Returns:
            Output tensor and optionally attention weights
        """
        # Scale is applied to the query once, up front.
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation; no extra scaling needed here since
            # the query was already scaled above.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            if self.attn_backend == "flex":
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                # SDPA path: convert the boolean mask to an additive float
                # mask (-inf at disallowed positions).
                sdpa_mask = None
                if attention_mask is not None:
                    sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                    sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=self.dropout_prob if self.training else 0.0,
                    scale=1.0
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
395
-
396
-
397
class EsmAttention(nn.Module):
    """Pre-LN attention sublayer: LayerNorm -> self-attention -> output projection
    with a residual against the un-normalized input."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run pre-norm self-attention over ``hidden_states``.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Attended tensor, plus attention probabilities when requested.
        """
        normed = self.LayerNorm(hidden_states)
        result = self.self(normed, attention_mask, flex_block_mask, output_attentions)
        if not output_attentions:
            # The output projection adds the residual against the pre-norm input.
            return self.output(result, hidden_states)
        attended, weights = result
        return self.output(attended, hidden_states), weights
435
-
436
-
437
class EsmLayer(nn.Module):
    """One transformer block: attention sublayer followed by a pre-norm
    feed-forward sublayer, each with its own residual connection."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for one transformer layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask
            output_attentions: Whether to return attention weights

        Returns:
            Layer output tensor, plus attention weights when requested.
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if output_attentions:
            attn_out, attn_weights = attn_result
            return self.feed_forward_chunk(attn_out), attn_weights
        return self.feed_forward_chunk(attn_result)

    def feed_forward_chunk(self, attention_output):
        """Pre-norm MLP: LayerNorm -> intermediate -> output (with residual)."""
        normed = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed)
        return self.output(expanded, attention_output)
487
-
488
-
489
class EsmEncoder(nn.Module):
    """Stack of EsmLayer blocks followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Toggled externally via PreTrainedModel.gradient_checkpointing_enable().
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            flex_block_mask: Optional flex-attention block mask, forwarded to
                each layer
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                # Record each layer's input; the final (post-norm) state is
                # appended after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            # Layers return a bare tensor, or (tensor, weights) when
            # output_attentions is set.
            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # nn.Module instances are always truthy, so this branch always runs;
        # the guard only protects against the attribute being set to None.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
556
-
557
-
558
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): evaluated at class-definition time, so merely importing this
    # module triggers a tokenizer download/cache lookup — consider lazy loading.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding at zero after re-init.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)
583
-
584
-
585
-
586
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    """Bare FastESM encoder (embeddings + transformer stack + contact head)."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        # add_pooling_layer is accepted for signature parity with FastEsmModel
        # but is not used here.
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table (HF resize/tie hook)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table (HF resize/tie hook)."""
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Minimal embed path used by EmbeddingMixin: returns last hidden state only."""
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        flex_block_mask = None
        if attention_mask is not None:
            token_attention_mask = attention_mask.bool()
            if self.config.attn_backend == "flex":
                assert create_block_mask is not None, (
                    "Flex attention backend requested but torch.create_block_mask is unavailable."
                )
                # Flex path: padding handled by the block mask, not a dense mask.
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
                extended_attention_mask = None
            else:
                # SDPA path: broadcast the 1D key mask to (batch, 1, q, kv).
                extended_attention_mask = token_attention_mask[:, None, None, :].expand(
                    batch_size, 1, seq_length, seq_length
                )
        else:
            extended_attention_mask = None
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contacts from all layers' attention maps."""
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        # (batch, layers, heads, seq, seq)
        attns = torch.stack(attns, dim=1)
        # Zero attention to and from padded positions.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        flex_block_mask = None
        if attention_mask is not None:
            token_attention_mask = attention_mask.bool()
            # Flex attention cannot return attention probabilities, so fall
            # back to the dense-mask path whenever they are requested.
            if (
                self.config.attn_backend == "flex"
                and not output_attentions
            ):
                assert create_block_mask is not None, (
                    "Flex attention backend requested but torch.create_block_mask is unavailable."
                )
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
                extended_attention_mask = None
            else:
                extended_attention_mask = token_attention_mask[:, None, None, :].expand(
                    batch_size, 1, seq_length, seq_length
                )
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
715
-
716
-
717
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder wrapped with an optional pooling layer (HF base-model API)."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table (HF resize/tie hook)."""
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table (HF resize/tie hook)."""
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate EmbeddingMixin's embed path to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states, pooled output (when the
            pooler exists), and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
791
-
792
-
793
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding table (HF resize/tie hook)."""
        return self.esm.embeddings.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head decoder (untied from input embeddings; see config)."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head decoder."""
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate EmbeddingMixin's embed path to the inner encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Run the encoder and LM head; compute cross-entropy loss when labels are given.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            labels: Optional (batch, seq_len) target token ids for MLM loss
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states

        Returns:
            EsmMaskedLMOutput with logits, last hidden state, and optional loss.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            # Flatten (batch, seq) -> (batch*seq) for token-level cross entropy.
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return EsmMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
851
-
852
-
853
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a sequence-level classification/regression head."""
    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        # One criterion per supported problem type; the right one is picked in forward().
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped encoder."""
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding extraction to the wrapped encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify (or regress on) whole sequences.

        The loss used depends on ``config.problem_type``; when unset it is
        inferred from ``num_labels`` and the label dtype on the first call
        with labels, and cached on the config for subsequent calls.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # NOTE: this mutates self.config as a one-time cache of the inferred problem type.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # BCEWithLogitsLoss expects float multi-hot targets.
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
924
-
925
-
926
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a per-token classification head (dropout + linear)."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped encoder."""
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding extraction to the wrapped encoder."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Produce one label distribution per token; cross-entropy loss when labels given."""
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        token_states = self.dropout(encoder_out.last_hidden_state)
        logits = self.classifier(token_states)

        token_loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.to(logits.device).view(-1)
            token_loss = self.loss_fct(flat_logits, flat_labels)

        return TokenClassifierOutput(
            loss=token_loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
980
-
981
-
 
1
+ import entrypoint_setup
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from typing import Optional, Tuple, Union, Dict, Any
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
+ from transformers.modeling_outputs import (
10
+ ModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ SequenceClassifierOutput,
14
+ TokenClassifierOutput
15
+ )
16
+ from transformers.models.esm.modeling_esm import (
17
+ EsmIntermediate,
18
+ EsmOutput,
19
+ EsmPooler,
20
+ EsmLMHead,
21
+ EsmSelfOutput,
22
+ EsmClassificationHead,
23
+ )
24
+ try:
25
+ from torch.nn.attention.flex_attention import create_block_mask
26
+ from torch.nn.attention.flex_attention import flex_attention
27
+ except ImportError:
28
+ create_block_mask = None
29
+ flex_attention = None
30
+
31
+ try:
32
+ from .embedding_mixin import EmbeddingMixin
33
+ except ImportError:
34
+ try:
35
+ from ..embedding_mixin import EmbeddingMixin
36
+ except ImportError:
37
+ from embedding_mixin import EmbeddingMixin
38
+
39
+
40
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    """Build a flex-attention block mask that hides padding tokens.

    Args:
        attention_mask_2d: (batch, seq) mask where nonzero marks real tokens.

    Returns:
        A BlockMask permitting attention only between pairs of valid tokens.
    """
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # A (query, key) pair participates only if both positions are real tokens.
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    return create_block_mask(
        mask_mod,
        batch_size,
        1,  # single mask shared across all heads
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )
56
+
57
+
58
@dataclass
class EsmMaskedLMOutput(ModelOutput):
    """Output of FastEsmForMaskedLM; adds the final hidden state to the usual MLM fields."""
    # Cross-entropy MLM loss (present only when labels are supplied).
    loss: Optional[torch.Tensor] = None
    # Vocabulary logits, shape (batch, seq, vocab_size).
    logits: Optional[torch.Tensor] = None
    # Final encoder hidden state, shape (batch, seq, hidden).
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention maps when output_attentions=True.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
65
+
66
+
67
class FastEsmConfig(PretrainedConfig):
    """Configuration for FastESM models.

    Mirrors the standard ESM configuration, plus an `attn_backend` switch
    ("sdpa" or "flex") selecting the attention implementation.
    """
    model_type = "fast_esm"
    def __init__(
        self,
        vocab_size: int = None,
        mask_token_id: int = None,
        pad_token_id: int = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: bool = None,
        token_dropout: bool = True,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        # Embeddings and LM head are kept separate (no weight tying).
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output
120
+
121
+
122
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    This is the standard helper for rotary position embeddings.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
125
+
126
+
127
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate `x` by the given cos/sin tables (truncated to x's sequence length)."""
    seq_len = x.shape[-2]
    cos_slice = cos[:, :, :seq_len, :]
    sin_slice = sin[:, :, :seq_len, :]
    return x * cos_slice + rotate_half(x) * sin_slice
132
+
133
+
134
def symmetrize(x: torch.Tensor) -> torch.Tensor:
    """Make a tensor symmetric in its final two dimensions (used for contact prediction)."""
    transposed = x.transpose(-2, -1)
    return x + transposed
137
+
138
+
139
def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    """Apply the average-product correction over the last two dims (contact prediction)."""
    row_sums = x.sum(-1, keepdim=True)
    col_sums = x.sum(-2, keepdim=True)
    grand_total = x.sum((-1, -2), keepdim=True)
    # Subtract the outer product of marginals, normalized by the total mass.
    return x - row_sums * col_sums / grand_total
149
+
150
+
151
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        super().__init__()
        # in_features == num_layers * num_heads: one channel per attention map.
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        """Map stacked attentions (batch, layers, heads, seq, seq) to per-pair contact
        probabilities of shape (batch, seq-2, seq-2) after stripping CLS/EOS rows/cols."""
        # remove eos token attentions
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Fold layers*heads into a channel dimension for the regression.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
184
+
185
+
186
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq  # no-op; kept for parity with the reference implementation
        self.register_buffer("inv_freq", inv_freq)

        # Lazily-built cos/sin tables, cached per (seq_len, device).
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cached cos/sin tables, rebuilding them if seq length or device changed."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate the frequency table so it spans the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors (batch, heads, seq, head_dim)."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
227
+
228
+
229
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing,
    plus ESM's masked-token dropout rescaling.
    """

    def __init__(self, config):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = config.position_embedding_type
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,  # currently unused here; accepted for interface compatibility
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ):
        """Embed tokens, optionally applying ESM token-dropout rescaling and pre-LN.

        Returns:
            (batch, seq, hidden) embeddings with padding positions zeroed.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            # Fix: the previous code called torch.ones_like(input_ids) unconditionally,
            # which crashed when callers provided inputs_embeds without input_ids.
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                attention_mask = torch.ones(
                    embeddings.shape[:-1], dtype=torch.long, device=embeddings.device
                )

        # Token dropout rescaling needs token ids to locate mask tokens, so it is
        # skipped (rather than crashing) when only inputs_embeds are provided.
        if self.token_dropout and input_ids is not None:
            # Zero out mask-token embeddings, then rescale so the expected norm
            # matches training, where 15% * 80% of tokens were masked.
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out padding positions.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
296
+
297
+
298
class EsmSelfAttention(nn.Module):
    """Multi-head self-attention with selectable backends (SDPA or flex attention)
    and optional rotary position embeddings."""

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # 1/sqrt(head_dim); applied to the query once so every backend runs with scale=1.0.
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.attn_backend = config.attn_backend
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor of shape (batch, seq, hidden).
            attention_mask: Optional boolean mask, broadcastable to
                (batch, heads, seq, seq); True marks positions that may attend.
            flex_block_mask: Optional block mask for the flex-attention backend.
            output_attentions: Whether to return attention weights (forces the
                manual, materialized-softmax path).

        Returns:
            Context tensor (batch, seq, hidden), plus attention probabilities
            when output_attentions is True.
        """
        # Query is pre-scaled here; all backends below are invoked with scale=1.0
        # to avoid double scaling.
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention so the probabilities can be returned (scaling already folded into Q).
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            if self.attn_backend == "flex":
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                # NOTE: flex path does not apply attention dropout.
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                # SDPA path: convert the boolean mask to an additive -inf mask.
                sdpa_mask = None
                if attention_mask is not None:
                    sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                    sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=sdpa_mask,
                    dropout_p=self.dropout_prob if self.training else 0.0,
                    scale=1.0
                )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
395
+
396
+
397
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention block: LN -> self-attention -> residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply pre-LN self-attention with a residual connection.

        Args:
            hidden_states: Input tensor (batch, seq, hidden).
            attention_mask: Optional attention mask.
            flex_block_mask: Optional block mask for the flex backend.
            output_attentions: Whether to also return attention weights.

        Returns:
            Output tensor, plus attention weights when requested.
        """
        normed = self.LayerNorm(hidden_states)
        attn_result = self.self(
            normed,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if not output_attentions:
            # Residual uses the *un-normalized* input (pre-LN convention).
            return self.output(attn_result, hidden_states)
        context, weights = attn_result
        return self.output(context, hidden_states), weights
435
+
436
+
437
class EsmLayer(nn.Module):
    """One transformer block: pre-LN self-attention followed by a pre-LN feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention then the feed-forward sub-block.

        Args:
            hidden_states: Input tensor (batch, seq, hidden).
            attention_mask: Optional attention mask.
            flex_block_mask: Optional block mask for the flex backend.
            output_attentions: Whether to also return attention weights.

        Returns:
            Layer output tensor, plus attention weights when requested.
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            flex_block_mask,
            output_attentions,
        )
        if output_attentions:
            attn_output, attn_weights = attn_result
            return self.feed_forward_chunk(attn_output), attn_weights
        return self.feed_forward_chunk(attn_result)

    def feed_forward_chunk(self, attention_output):
        """Pre-LN feed-forward with residual connection to the attention output."""
        normed = self.LayerNorm(attention_output)
        return self.output(self.intermediate(normed), attention_output)
487
+
488
+
489
class EsmEncoder(nn.Module):
    """Stack of EsmLayer blocks followed by a final layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for the transformer encoder.

        Args:
            hidden_states: Input embeddings (batch, seq, hidden).
            attention_mask: Optional pre-expanded mask for the SDPA/manual paths.
            flex_block_mask: Optional block mask for the flex-attention backend.
            output_hidden_states: Whether to collect every layer's hidden state.
            output_attentions: Whether to collect per-layer attention weights.

        Returns:
            BaseModelOutputWithPastAndCrossAttentions with the final hidden state
            and any collected intermediate states/attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward to trade compute for memory.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    flex_block_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        # Fix: check presence explicitly. The previous `if self.emb_layer_norm_after:`
        # relied on nn.Module truthiness, which is fragile (container modules define
        # __len__, so an empty container would evaluate False); `is not None` states
        # the actual intent.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
556
+
557
+
558
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    # NOTE(review): this runs at class-definition (import) time and may hit the
    # network or local cache to fetch the tokenizer — consider lazy initialization.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        # Normal(0, initializer_range) for linear/embedding weights; zeros for
        # biases and padding rows; identity scaling for LayerNorm.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)
583
+
584
+
585
+
586
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    """Core FastESM encoder: embeddings + transformer stack + contact head."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        # One regression channel per (layer, head) attention map.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Minimal encode path (no hidden-state/attention collection) returning the
        final hidden state; used by EmbeddingMixin."""
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        flex_block_mask = None
        if attention_mask is not None:
            token_attention_mask = attention_mask.bool()
            if self.config.attn_backend == "flex":
                assert create_block_mask is not None, (
                    "Flex attention backend requested but torch.create_block_mask is unavailable."
                )
                # Flex backend consumes a block mask instead of a dense mask.
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
                extended_attention_mask = None
            else:
                # Expand (batch, seq) -> (batch, 1, seq, seq) for SDPA/manual paths.
                extended_attention_mask = token_attention_mask[:, None, None, :].expand(
                    batch_size, 1, seq_length, seq_length
                )
        else:
            extended_attention_mask = None
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Predict residue-residue contacts from the stacked attention maps."""
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)  # -> (batch, layers, heads, seq, seq)
        # In the original model, attentions for padding tokens are completely zeroed out;
        # reproduce that by masking both the query and key axes.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        flex_block_mask = None
        if attention_mask is not None:
            token_attention_mask = attention_mask.bool()
            # Flex backend cannot return attention weights, so fall back to the
            # dense-mask path when output_attentions is requested.
            if (
                self.config.attn_backend == "flex"
                and not output_attentions
            ):
                assert create_block_mask is not None, (
                    "Flex attention backend requested but torch.create_block_mask is unavailable."
                )
                flex_block_mask = _create_pad_block_mask(token_attention_mask)
                extended_attention_mask = None
            else:
                extended_attention_mask = token_attention_mask[:, None, None, :].expand(
                    batch_size, 1, seq_length, seq_length
                )
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            flex_block_mask=flex_block_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
715
+
716
+
717
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM base transformer with an optional pooling head on top."""

    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        # Pooler is optional; when disabled, pooler_output is returned as None.
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Delegate to the wrapped encoder's embedding path.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Run the encoder and (optionally) pool the first token.

        Args:
            input_ids: Input token IDs (mutually exclusive with inputs_embeds).
            attention_mask: Optional padding mask.
            position_ids: Optional position IDs.
            inputs_embeds: Optional precomputed input embeddings.
            output_hidden_states: Whether to return all hidden states.
            output_attentions: Whether to return attention weights.

        Returns:
            BaseModelOutputWithPoolingAndCrossAttentions including the pooled
            output when a pooler was configured.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Validate input combination before delegating to the encoder.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        hidden = encoder_out.last_hidden_state
        pooled = None if self.pooler is None else self.pooler(hidden)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden,
            pooler_output=pooled,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
791
+
792
+
793
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a masked-language-modeling head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        # Default ignore_index (-100) skips unmasked positions in the labels.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        """Score each position over the vocabulary; compute MLM loss if labels given."""
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        hidden = encoder_out.last_hidden_state
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Flatten (batch, seq, vocab) against (batch, seq) targets.
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = self.loss_fct(flat_logits, labels.to(logits.device).view(-1))

        return EsmMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
851
+
852
+
853
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a sequence-level classification/regression head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        # One criterion per problem type; selected at forward time.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify/regress per sequence; infers problem type on first labeled call."""
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(encoder_out.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Lazily infer and cache the problem type on the config,
            # mirroring standard HF sequence-classification behavior.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
924
+
925
+
926
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    """FastESM encoder with a per-token classification head."""

    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Default ignore_index (-100) lets callers mask out special tokens.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Predict a label for every token; computes CE loss when labels given."""
        encoder_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(self.dropout(encoder_out.last_hidden_state))

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            loss = self.loss_fct(flat_logits, labels.to(logits.device).view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
980
+
981
+