lhallee committed on
Commit
a99728d
·
verified ·
1 Parent(s): f4a8cd1

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +46 -238
modeling_fastesm.py CHANGED
@@ -1,4 +1,3 @@
1
- import entrypoint_setup
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
@@ -6,6 +5,7 @@ from typing import Optional, Tuple, Union, Dict, Any
6
  from einops import rearrange
7
  from dataclasses import dataclass
8
  from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 
9
  from transformers.modeling_outputs import (
10
  ModelOutput,
11
  BaseModelOutputWithPastAndCrossAttentions,
@@ -20,6 +20,9 @@ from transformers.models.esm.modeling_esm import (
20
  EsmLMHead,
21
  EsmSelfOutput,
22
  EsmClassificationHead,
 
 
 
23
  )
24
  try:
25
  from torch.nn.attention.flex_attention import create_block_mask
@@ -28,13 +31,7 @@ except ImportError:
28
  create_block_mask = None
29
  flex_attention = None
30
 
31
- try:
32
- from .embedding_mixin import EmbeddingMixin
33
- except ImportError:
34
- try:
35
- from ..embedding_mixin import EmbeddingMixin
36
- except ImportError:
37
- from embedding_mixin import EmbeddingMixin
38
 
39
 
40
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
@@ -80,7 +77,7 @@ class FastEsmConfig(PretrainedConfig):
80
  max_position_embeddings: int = 1026,
81
  initializer_range: float = 0.02,
82
  layer_norm_eps: float = 1e-12,
83
- position_embedding_type: str = "absolute",
84
  emb_layer_norm_before: bool = None,
85
  token_dropout: bool = True,
86
  attn_backend: str = "sdpa",
@@ -119,182 +116,6 @@ class FastEsmConfig(PretrainedConfig):
119
  return output
120
 
121
 
122
- def rotate_half(x: torch.Tensor) -> torch.Tensor:
123
- x1, x2 = x.chunk(2, dim=-1)
124
- return torch.cat((-x2, x1), dim=-1)
125
-
126
-
127
- def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
128
- cos = cos[:, :, : x.shape[-2], :]
129
- sin = sin[:, :, : x.shape[-2], :]
130
-
131
- return (x * cos) + (rotate_half(x) * sin)
132
-
133
-
134
- def symmetrize(x: torch.Tensor) -> torch.Tensor:
135
- "Make layer symmetric in final two dimensions, used for contact prediction."
136
- return x + x.transpose(-1, -2)
137
-
138
-
139
- def average_product_correct(x: torch.Tensor) -> torch.Tensor:
140
- "Perform average product correct, used for contact prediction."
141
- a1 = x.sum(-1, keepdims=True)
142
- a2 = x.sum(-2, keepdims=True)
143
- a12 = x.sum((-1, -2), keepdims=True)
144
-
145
- avg = a1 * a2
146
- avg.div_(a12) # in-place to reduce memory
147
- normalized = x - avg
148
- return normalized
149
-
150
-
151
- class EsmContactPredictionHead(nn.Module):
152
- """Performs symmetrization, apc, and computes a logistic regression on the output features"""
153
-
154
- def __init__(
155
- self,
156
- in_features: int,
157
- bias: bool = True,
158
- eos_idx: int = 2,
159
- ):
160
- super().__init__()
161
- self.in_features = in_features
162
- self.eos_idx = eos_idx
163
- self.regression = nn.Linear(in_features, 1, bias=bias)
164
- self.activation = nn.Sigmoid()
165
-
166
- def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
167
- # remove eos token attentions
168
- eos_mask = input_ids.ne(self.eos_idx).to(attentions)
169
- eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
170
- attentions = attentions * eos_mask[:, None, None, :, :]
171
- attentions = attentions[..., :-1, :-1]
172
- # remove cls token attentions
173
- attentions = attentions[..., 1:, 1:]
174
- batch_size, layers, heads, seqlen, _ = attentions.size()
175
- attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
176
-
177
- # features: batch x channels x tokens x tokens (symmetric)
178
- attentions = attentions.to(
179
- self.regression.weight.device
180
- ) # attentions always float32, may need to convert to float16
181
- attentions = average_product_correct(symmetrize(attentions))
182
- attentions = attentions.permute(0, 2, 3, 1)
183
- return self.activation(self.regression(attentions).squeeze(3))
184
-
185
-
186
- class RotaryEmbedding(torch.nn.Module):
187
- """
188
- Rotary position embeddings based on those in
189
- [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
190
- matrices which depend on their relative positions.
191
- """
192
-
193
- def __init__(self, dim: int):
194
- super().__init__()
195
- # Generate and save the inverse frequency buffer (non trainable)
196
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
197
- inv_freq = inv_freq
198
- self.register_buffer("inv_freq", inv_freq)
199
-
200
- self._seq_len_cached = None
201
- self._cos_cached = None
202
- self._sin_cached = None
203
-
204
- def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
205
- seq_len = x.shape[seq_dimension]
206
-
207
- # Reset the tables if the sequence length has changed,
208
- # or if we're on a new device (possibly due to tracing for instance)
209
- if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
210
- self._seq_len_cached = seq_len
211
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
212
- freqs = torch.outer(t, self.inv_freq)
213
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
214
-
215
- self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
216
- self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
217
-
218
- return self._cos_cached, self._sin_cached
219
-
220
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
221
- self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
222
-
223
- return (
224
- apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
225
- apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
226
- )
227
-
228
-
229
- class EsmEmbeddings(nn.Module):
230
- """
231
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
232
- """
233
-
234
- def __init__(self, config):
235
- super().__init__()
236
- self.padding_idx = config.pad_token_id
237
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
238
- if config.emb_layer_norm_before:
239
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
240
- else:
241
- self.layer_norm = None
242
- self.position_embedding_type = config.position_embedding_type
243
- self.register_buffer(
244
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
245
- )
246
- self.token_dropout = config.token_dropout
247
- self.mask_token_id = config.mask_token_id
248
-
249
- def forward(
250
- self,
251
- input_ids: Optional[torch.Tensor] = None,
252
- attention_mask: Optional[torch.Tensor] = None,
253
- position_ids: Optional[torch.Tensor] = None,
254
- inputs_embeds: Optional[torch.Tensor] = None,
255
- past_key_values_length: Optional[int] = 0,
256
- ):
257
- if inputs_embeds is None:
258
- inputs_embeds = self.word_embeddings(input_ids)
259
-
260
- embeddings = inputs_embeds
261
-
262
- if attention_mask is None:
263
- attention_mask = torch.ones_like(input_ids)
264
-
265
- if self.token_dropout:
266
- embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
267
- mask_ratio_train = 0.15 * 0.8
268
- src_lengths = attention_mask.sum(-1)
269
- mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
270
- embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
271
- embeddings.dtype
272
- )
273
-
274
- if self.layer_norm is not None:
275
- embeddings = self.layer_norm(embeddings)
276
- if attention_mask is not None:
277
- embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
278
- return embeddings
279
-
280
- def create_position_ids_from_inputs_embeds(self, inputs_embeds):
281
- """
282
- We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
283
-
284
- Args:
285
- inputs_embeds: torch.Tensor
286
-
287
- Returns: torch.Tensor
288
- """
289
- input_shape = inputs_embeds.size()[:-1]
290
- sequence_length = input_shape[1]
291
-
292
- position_ids = torch.arange(
293
- self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
294
- )
295
- return position_ids.unsqueeze(0).expand(input_shape)
296
-
297
-
298
  class EsmSelfAttention(nn.Module):
299
  def __init__(self, config, position_embedding_type: Optional[str] = None):
300
  super().__init__()
@@ -322,9 +143,6 @@ class EsmSelfAttention(nn.Module):
322
  if self.position_embedding_type == "rotary":
323
  self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
324
 
325
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
326
- return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
327
-
328
  def forward(
329
  self,
330
  hidden_states: torch.Tensor,
@@ -342,9 +160,13 @@ class EsmSelfAttention(nn.Module):
342
  Returns:
343
  Output tensor and optionally attention weights
344
  """
345
- query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
346
- key_layer = self.transpose_for_scores(self.key(hidden_states))
347
- value_layer = self.transpose_for_scores(self.value(hidden_states))
 
 
 
 
348
 
349
  if self.position_embedding_type == "rotary":
350
  query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
@@ -363,32 +185,23 @@ class EsmSelfAttention(nn.Module):
363
  else:
364
  if self.attn_backend == "flex":
365
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
366
- assert query_layer.dtype in (torch.float16, torch.bfloat16), (
367
- f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
368
- )
369
- if attention_mask is not None:
370
- assert flex_block_mask is not None, (
371
- "Flex attention backend requires a block mask when attention_mask is provided."
372
- )
373
  context_layer = flex_attention(
374
  query_layer,
375
  key_layer,
376
  value_layer,
377
  block_mask=flex_block_mask,
378
- scale=1.0,
379
  )
380
  else:
381
- sdpa_mask = None
382
- if attention_mask is not None:
383
- sdpa_mask = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
384
- sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
385
  context_layer = F.scaled_dot_product_attention(
386
  query_layer,
387
  key_layer,
388
  value_layer,
389
- attn_mask=sdpa_mask,
390
  dropout_p=self.dropout_prob if self.training else 0.0,
391
- scale=1.0
392
  )
393
  context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
394
  return context_layer
@@ -565,22 +378,23 @@ class FastEsmPreTrainedModel(PreTrainedModel):
565
  supports_gradient_checkpointing = True
566
  tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
567
  all_tied_weights_keys = {}
568
-
 
569
  def _init_weights(self, module):
570
  """Initialize the weights"""
571
- if isinstance(module, nn.Linear):
572
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
573
- if module.bias is not None:
574
- module.bias.data.zero_()
575
- elif isinstance(module, nn.Embedding):
576
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
577
- if module.padding_idx is not None:
578
- module.weight.data[module.padding_idx].zero_()
579
- elif isinstance(module, nn.LayerNorm):
580
- if module.bias is not None:
581
- module.bias.data.zero_()
582
- module.weight.data.fill_(1.0)
583
 
 
 
 
 
584
 
585
 
586
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
@@ -678,25 +492,19 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
678
  attention_mask=attention_mask,
679
  inputs_embeds=inputs_embeds,
680
  )
681
-
682
- flex_block_mask = None
683
- if attention_mask is not None:
684
- token_attention_mask = attention_mask.bool()
685
- if (
686
- self.config.attn_backend == "flex"
687
- and not output_attentions
688
- ):
689
- assert create_block_mask is not None, (
690
- "Flex attention backend requested but torch.create_block_mask is unavailable."
691
- )
692
- flex_block_mask = _create_pad_block_mask(token_attention_mask)
693
- extended_attention_mask = None
694
- else:
695
- extended_attention_mask = token_attention_mask[:, None, None, :].expand(
696
- batch_size, 1, seq_length, seq_length
697
- )
698
  else:
 
 
 
 
 
699
  extended_attention_mask = None
 
 
 
700
 
701
  encoder_outputs = self.encoder(
702
  token_embedding_output,
@@ -796,7 +604,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
796
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
797
  self.lm_head = EsmLMHead(config)
798
  self.loss_fct = nn.CrossEntropyLoss()
799
- self.init_weights()
800
 
801
  def get_input_embeddings(self):
802
  return self.esm.embeddings.word_embeddings
@@ -860,7 +668,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
860
  self.mse = nn.MSELoss()
861
  self.ce = nn.CrossEntropyLoss()
862
  self.bce = nn.BCEWithLogitsLoss()
863
- self.init_weights()
864
 
865
  def get_input_embeddings(self):
866
  return self.esm.embeddings.word_embeddings
@@ -931,7 +739,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
931
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
932
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
933
  self.loss_fct = nn.CrossEntropyLoss()
934
- self.init_weights()
935
 
936
  def get_input_embeddings(self):
937
  return self.esm.embeddings.word_embeddings
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
 
5
  from einops import rearrange
6
  from dataclasses import dataclass
7
  from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
8
+ from transformers import initialization as init
9
  from transformers.modeling_outputs import (
10
  ModelOutput,
11
  BaseModelOutputWithPastAndCrossAttentions,
 
20
  EsmLMHead,
21
  EsmSelfOutput,
22
  EsmClassificationHead,
23
+ EsmContactPredictionHead,
24
+ EsmEmbeddings,
25
+ RotaryEmbedding,
26
  )
27
  try:
28
  from torch.nn.attention.flex_attention import create_block_mask
 
31
  create_block_mask = None
32
  flex_attention = None
33
 
34
+ from embedding_mixin import EmbeddingMixin
 
 
 
 
 
 
35
 
36
 
37
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
 
77
  max_position_embeddings: int = 1026,
78
  initializer_range: float = 0.02,
79
  layer_norm_eps: float = 1e-12,
80
+ position_embedding_type: str = "rotary",
81
  emb_layer_norm_before: bool = None,
82
  token_dropout: bool = True,
83
  attn_backend: str = "sdpa",
 
116
  return output
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  class EsmSelfAttention(nn.Module):
120
  def __init__(self, config, position_embedding_type: Optional[str] = None):
121
  super().__init__()
 
143
  if self.position_embedding_type == "rotary":
144
  self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
145
 
 
 
 
146
  def forward(
147
  self,
148
  hidden_states: torch.Tensor,
 
160
  Returns:
161
  Output tensor and optionally attention weights
162
  """
163
+ batch_size, seq_length = hidden_states.shape[:-1]
164
+ hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
165
+ query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
166
+ key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
167
+ value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
168
+
169
+ query_layer = query_layer * self.scale
170
 
171
  if self.position_embedding_type == "rotary":
172
  query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
185
  else:
186
  if self.attn_backend == "flex":
187
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
188
+ assert query_layer.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
189
+ assert flex_block_mask is not None, "Flex attention backend requires a block mask"
 
 
 
 
 
190
  context_layer = flex_attention(
191
  query_layer,
192
  key_layer,
193
  value_layer,
194
  block_mask=flex_block_mask,
195
+ scale=1.0, # applied before rotary
196
  )
197
  else:
 
 
 
 
198
  context_layer = F.scaled_dot_product_attention(
199
  query_layer,
200
  key_layer,
201
  value_layer,
202
+ attn_mask=attention_mask,
203
  dropout_p=self.dropout_prob if self.training else 0.0,
204
+ scale=1.0 # applied before rotary
205
  )
206
  context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
207
  return context_layer
 
378
  supports_gradient_checkpointing = True
379
  tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
380
  all_tied_weights_keys = {}
381
+
382
+ @torch.no_grad()
383
  def _init_weights(self, module):
384
  """Initialize the weights"""
385
+ super()._init_weights(module)
386
+ if isinstance(module, EsmLMHead):
387
+ init.zeros_(module.bias)
388
+ elif isinstance(module, EsmEmbeddings):
389
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
390
+ elif isinstance(module, RotaryEmbedding):
391
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
392
+ init.copy_(module.inv_freq, inv_freq)
 
 
 
 
393
 
394
+ def get_output_embeddings(self):
395
+ # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
396
+ # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
397
+ return None
398
 
399
 
400
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
 
492
  attention_mask=attention_mask,
493
  inputs_embeds=inputs_embeds,
494
  )
495
+
496
+ if attention_mask is None:
497
+ token_attention_mask = torch.ones((batch_size, seq_length), device=input_ids.device).bool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  else:
499
+ token_attention_mask = attention_mask.bool()
500
+
501
+ if self.config.attn_backend == "flex" and not output_attentions:
502
+ assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
503
+ flex_block_mask = _create_pad_block_mask(token_attention_mask)
504
  extended_attention_mask = None
505
+ else:
506
+ flex_block_mask = None
507
+ extended_attention_mask = token_attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length)
508
 
509
  encoder_outputs = self.encoder(
510
  token_embedding_output,
 
604
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
605
  self.lm_head = EsmLMHead(config)
606
  self.loss_fct = nn.CrossEntropyLoss()
607
+ self.post_init()
608
 
609
  def get_input_embeddings(self):
610
  return self.esm.embeddings.word_embeddings
 
668
  self.mse = nn.MSELoss()
669
  self.ce = nn.CrossEntropyLoss()
670
  self.bce = nn.BCEWithLogitsLoss()
671
+ self.post_init()
672
 
673
  def get_input_embeddings(self):
674
  return self.esm.embeddings.word_embeddings
 
739
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
740
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
741
  self.loss_fct = nn.CrossEntropyLoss()
742
+ self.post_init()
743
 
744
  def get_input_embeddings(self):
745
  return self.esm.embeddings.word_embeddings