lhallee committed (verified) · fa6af4e · 1 Parent(s): 3c063b5

Upload test_model.py with huggingface_hub

Files changed (1): test_model.py (+1126, -0)

test_model.py (ADDED)

import torch
import torch.nn as nn
import os
import networkx as nx
from torch.nn import functional as F
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader
from typing import Optional, Tuple, Union, Callable, List, Dict, Any
from einops import rearrange
from dataclasses import dataclass
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer, PreTrainedTokenizerBase
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.esm.modeling_esm import (
    EsmIntermediate,
    EsmOutput,
    EsmPooler,
    EsmLMHead,
    EsmSelfOutput,
    EsmClassificationHead,
)
from tqdm.auto import tqdm

from pooler import Pooler


@dataclass
class EsmMaskedLMOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


class FastEsmConfig(PretrainedConfig):
    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size: int = None,
        mask_token_id: int = None,
        pad_token_id: int = None,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1026,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        emb_layer_norm_before: bool = None,
        token_dropout: bool = True,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before
        self.tie_word_embeddings = False
        self.token_dropout = token_dropout

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)
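

# Illustrative sketch (not part of the original upload): rotate_half maps the two
# halves of the last dimension [x1, x2] -> [-x2, x1], so together with the cached
# cos/sin tables, (x * cos) + (rotate_half(x) * sin) realizes the RoFormer rotation.
def _demo_rotate_half() -> torch.Tensor:
    x = torch.arange(4.0)  # tensor([0., 1., 2., 3.])
    return rotate_half(x)  # tensor([-2., -3., 0., 1.])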


def symmetrize(x: torch.Tensor) -> torch.Tensor:
    """Make layer symmetric in final two dimensions, used for contact prediction."""
    return x + x.transpose(-1, -2)


def average_product_correct(x: torch.Tensor) -> torch.Tensor:
    """Perform average product correction, used for contact prediction."""
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
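

# Illustrative sketch (not part of the original upload): APC subtracts the outer
# product of row and column sums divided by the total sum, removing background
# coupling; symmetrize + APC preserves symmetry in the last two dimensions.
def _demo_average_product_correct() -> torch.Tensor:
    attn = torch.rand(1, 2, 8, 8)  # (batch, channels, tokens, tokens)
    corrected = average_product_correct(symmetrize(attn))
    assert torch.allclose(corrected, corrected.transpose(-1, -2))  # still symmetric
    return corrected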


class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias: bool = True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias=bias)
        self.activation = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
        # remove eos token attentions
        eos_mask = input_ids.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
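

# Illustrative sketch (not part of the original upload): queries and keys are
# shaped (batch, heads, seq, head_dim); the module caches cos/sin tables per
# sequence length and device, so repeated calls at the same length are cheap.
def _demo_rotary_embedding() -> Tuple[torch.Tensor, torch.Tensor]:
    rope = RotaryEmbedding(dim=16)
    q = torch.randn(1, 4, 10, 16)
    k = torch.randn(1, 4, 10, 16)
    q_rot, k_rot = rope(q, k)  # same shapes as q and k, rotated by position
    return q_rot, k_rot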


class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.padding_idx = config.pad_token_id  # required by create_position_ids_from_inputs_embeds
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: Optional[int] = 0,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if self.token_dropout:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
            mask_ratio_train = 0.15 * 0.8  # fraction of tokens actually masked during ESM pre-training
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)


class EsmSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        # Queries are pre-scaled once here, so both branches below must use an
        # effective scale of 1.0 to avoid applying the scaling twice.
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation; queries are already scaled above
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                if attention_mask.dtype == torch.bool:
                    # boolean masks mark positions that may be attended to
                    attention_scores = attention_scores.masked_fill(
                        ~attention_mask, torch.finfo(attention_scores.dtype).min
                    )
                else:
                    attention_scores = attention_scores + attention_mask
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=attention_mask,
                dropout_p=self.dropout_prob if self.training else 0.0,  # no dropout at eval time
                scale=1.0,  # queries already scaled
            )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer
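

# Illustrative sketch (not part of the original upload): with dropout disabled,
# the manual branch (output_attentions=True) and the fused SDPA branch should
# agree up to floating point tolerance. The config values here are arbitrary.
def _demo_attention_paths_agree() -> bool:
    cfg = FastEsmConfig(
        vocab_size=33, hidden_size=32, num_hidden_layers=1, num_attention_heads=4,
        intermediate_size=64, attention_probs_dropout_prob=0.0, position_embedding_type="rotary",
    )
    attn = EsmSelfAttention(cfg).eval()
    x = torch.randn(2, 5, 32)
    with torch.no_grad():
        manual, _ = attn(x, output_attentions=True)
        fused = attn(x, output_attentions=False)
    return torch.allclose(manual, fused, atol=1e-5)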


class EsmAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for attention layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            output_attentions,
        )
        if output_attentions:
            attention_output, attention_weights = self_outputs
            attention_output = self.output(attention_output, hidden_states)
            return attention_output, attention_weights
        else:
            attention_output = self_outputs
            return self.output(attention_output, hidden_states)


class EsmLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for transformer layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            output_attentions,
        )
        if output_attentions:
            attention_output, attention_weights = attention_outputs
        else:
            attention_output = attention_outputs
            attention_weights = None

        layer_output = self.feed_forward_chunk(attention_output)

        if output_attentions:
            return layer_output, attention_weights
        return layer_output

    def feed_forward_chunk(self, attention_output):
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class EsmEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


class ProteinDataset(TorchDataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: List[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


def build_collator(tokenizer) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        """Collate function for batching sequences."""
        return tokenizer(sequences, return_tensors="pt", padding='longest')
    return _collate_fn
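

# Illustrative sketch (not part of the original upload): the dataset and collator
# plug directly into a torch DataLoader; each batch is a mapping with input_ids
# and attention_mask, padded to the longest sequence in the batch. The tokenizer
# argument is assumed to be an EsmTokenizer loaded elsewhere.
def _demo_dataloader(tokenizer) -> DataLoader:
    dataset = ProteinDataset(["MALWMRLLPL", "MKT", "MEEPQSDPSV"])
    return DataLoader(dataset, batch_size=2, collate_fn=build_collator(tokenizer), shuffle=False)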


class EmbeddingMixin:
    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database."""
        import sqlite3
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def embed_dataset(
        self,
        sequences: List[str],
        tokenizer: PreTrainedTokenizerBase,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
    ) -> Optional[dict[str, torch.Tensor]]:
        """Embed a dataset of protein sequences.

        Args:
            sequences: List of protein sequences
            tokenizer: Tokenizer used to build batches
            batch_size: Batch size for processing
            max_len: Maximum sequence length
            truncate: Whether to truncate sequences longer than max_len
            full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
            embed_dtype: Output dtype for saved embeddings
            pooling_types: Types of pooling to apply (e.g. 'mean', 'cls')
            num_workers: Number of workers for data loading, 0 for the main process
            sql: Whether to store embeddings in SQLite database - will be stored in float32
            sql_db_path: Path to SQLite database
            save: Whether to save embeddings as a .pth file
            save_path: Path for the .pth file

        Returns:
            Dictionary mapping sequences to embeddings, or None if sql=True

        Note:
            - If sql=True, embeddings can only be stored in float32
            - sql is ideal if you need to stream a very large dataset for training in real-time
            - save=True is ideal if you can store the entire embedding dictionary in RAM
            - sql will be used if it is True, whether save is True or False
            - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
            - Sequences will be truncated to max_len and sorted by length in descending order for faster processing

        Example:
            >>> embedding_dict = model.embed_dataset(
                sequences=[
                    'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
                ],
                tokenizer=tokenizer,
                batch_size=2, # adjust for your GPU memory
                max_len=512, # adjust for your needs
                full_embeddings=False, # if True, no pooling is performed
                embed_dtype=torch.float32, # cast to what dtype you want
                pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
                num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
                sql=False, # if True, embeddings will be stored in SQLite database
                sql_db_path='embeddings.db',
                save=True, # if True, embeddings will be saved as a .pth file
                save_path='embeddings.pth',
            )
            >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
        """
        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        collate_fn = build_collator(tokenizer)
        device = self.device
        pooler = Pooler(pooling_types) if not full_embeddings else None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            if full_embeddings or residue_embeddings.ndim == 2:  # if already pooled or want residue-wise embeddings
                return residue_embeddings
            else:
                return pooler(residue_embeddings, attention_mask)

        if sql:
            import sqlite3
            conn = sqlite3.connect(sql_db_path)
            c = conn.cursor()
            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
            already_embedded = self._read_sequences_from_db(sql_db_path)
            to_embed = [seq for seq in sequences if seq not in already_embedded]
            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
            print(f"Embedding {len(to_embed)} new sequences")
            if len(to_embed) > 0:
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                with torch.no_grad():
                    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                        seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                        residue_embeddings = self._embed(input_ids, attention_mask).float()  # sql requires float32
                        embeddings = get_embeddings(residue_embeddings, attention_mask)
                        for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                            if full_embeddings:
                                emb = emb[mask.bool()].reshape(-1, hidden_size)
                            c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                      (seq, emb.cpu().numpy().tobytes()))

                        if (i + 1) % 100 == 0:
                            conn.commit()

                conn.commit()
            conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if len(to_embed) > 0:
            dataset = ProteinDataset(to_embed)
            dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
            with torch.no_grad():
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                        embeddings_dict[seq] = emb.cpu()

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict


class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings


class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
        batch_size, seq_length = input_ids.shape
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None
        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return encoder_outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        token_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            token_embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
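

# Illustrative usage sketch (not part of the original upload): run the base model
# on one sequence and grab the final hidden states. Assumes network access and
# that the facebook/esm2 checkpoint loads into this class, as the test at the
# bottom of this file does.
def _demo_fast_esm_model() -> torch.Tensor:
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    model = FastEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D").eval()
    batch = tokenizer("MALWMRLLPL", return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)
    return out.last_hidden_state  # (1, seq_len, hidden_size)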


class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
        **kwargs,
    ) -> Union[Tuple, EsmMaskedLMOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return EsmMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
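

# Illustrative sketch (not part of the original upload): score a masked position
# with the MLM head, mirroring the comparison test at the bottom of this file.
# The model and tokenizer arguments are assumed to be loaded by the caller.
def _demo_masked_lm(model: "FastEsmForMaskedLM", tokenizer: EsmTokenizer) -> torch.Tensor:
    batch = tokenizer("MAL<mask>MRLLPL", return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits  # (1, seq_len, vocab_size)
    mask_pos = (batch["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
    return logits[0, mask_pos].softmax(-1)  # distribution over tokens at the mask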


class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
    def __init__(self, config, **kwargs):
        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsmModel and the HF EsmModel.
    In full precision, the differences are very small but nonzero due to floating point issues with F.scaled_dot_product_attention.
    In PyTorch 2.5+ (and a Linux kernel), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        config = FastEsmConfig.from_pretrained(model_path)
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path, config=config).to(device)
        print('fast model')
        print(fast_model)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
        print('transformers model')
        print(model)

        counts = [0] * len(tolerances)
        for _ in range(seq_count):
            example_seq = generate_random_sequence(length)
            fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            for i, atol in enumerate(tolerances):
                if torch.allclose(fast_output, model_output, atol=atol):
                    counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()