lhallee committed
Commit 081d607 · verified · 1 Parent(s): 3b1edaf

Upload modeling_fastesm.py with huggingface_hub

Files changed (1)
  1. modeling_fastesm.py +25 -21
modeling_fastesm.py CHANGED
@@ -364,7 +364,6 @@ from typing import Optional, Tuple, Union, Dict, Any
 from einops import rearrange
 from dataclasses import dataclass
 from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
-from transformers import initialization as init
 from transformers.modeling_outputs import (
     ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -399,9 +398,9 @@ def get_attention_mask(
     attention_mask: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
     if attention_mask is None:
-        token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
+        attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
     else:
-        token_attention_mask = attention_mask.bool()
+        attention_mask_2d = attention_mask.bool()
 
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
@@ -409,8 +408,10 @@
         if attention_mask is None:
             flex_block_mask = None
         else:
+            valid_lens = attention_mask_2d.sum(dim=-1)
+
             def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-                return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
+                return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
 
             flex_block_mask = create_block_mask(
                 mask_mod,
@@ -420,12 +421,12 @@
                 seq_len,
                 device=device,
             )
-        extended_attention_mask = None
+        attention_mask_4d = None
     else:
         flex_block_mask = None
-        extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
+        attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
-    return extended_attention_mask, flex_block_mask
+    return attention_mask_4d, flex_block_mask
 
 
 @dataclass
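
Note (not part of the commit): a minimal standalone sketch of what the reworked get_attention_mask produces. The 2D padding mask is broadcast to a (batch, 1, seq_len, seq_len) boolean mask for the non-flex path, and for right-padded batches the new valid-length predicate used by the flex backend selects exactly the same positions. The tensor values below are illustrative only.

import torch

# Illustrative right-padded batch: True marks real tokens, False marks padding.
attention_mask_2d = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
]).bool()
batch_size, seq_len = attention_mask_2d.shape

# Non-flex path: 4D boolean mask of shape (batch, 1, seq_len, seq_len);
# a (query, key) pair is attendable only if both tokens are real.
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]

# Flex path: the new mask_mod keeps (q_idx, kv_idx) iff both indices fall
# inside the sequence's valid length, which matches the 4D mask above for
# right-padded inputs.
valid_lens = attention_mask_2d.sum(dim=-1)              # tensor([3, 5])
q_idx = torch.arange(seq_len).view(1, seq_len, 1)
kv_idx = torch.arange(seq_len).view(1, 1, seq_len)
predicate = (q_idx < valid_lens.view(-1, 1, 1)) & (kv_idx < valid_lens.view(-1, 1, 1))
assert torch.equal(predicate.unsqueeze(1), attention_mask_4d)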
@@ -763,16 +764,19 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         return True
 
     @torch.no_grad()
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        super()._init_weights(module)
-        if isinstance(module, EsmLMHead):
-            init.zeros_(module.bias)
-        elif isinstance(module, EsmEmbeddings):
-            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
-        elif isinstance(module, RotaryEmbedding):
-            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
-            init.copy_(module.inv_freq, inv_freq)
+    def _init_weights(self, module: nn.Module) -> None:
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def post_init(self) -> None:
+        super().post_init()
 
     def get_output_embeddings(self):
         # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
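
Note (not part of the commit, assuming standard PreTrainedModel behavior and an illustrative std of 0.02): post_init() ultimately walks every submodule and calls an _init_weights hook like the one above on each of them, which is why handling nn.Linear and nn.Embedding directly removes the need for the transformers initialization helpers dropped from the imports.

import torch.nn as nn

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    # Mirrors the hook above: normal-init weights, zero biases and padding rows.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

# nn.Module.apply visits every submodule recursively, the same way the
# PreTrainedModel machinery does when post_init() runs.
toy = nn.Sequential(nn.Embedding(10, 8, padding_idx=0), nn.Linear(8, 4))
toy.apply(init_weights)
assert toy[0].weight.data[0].abs().sum() == 0  # padding row zeroed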
@@ -809,7 +813,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
 
     def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
-        attention_mask, flex_block_mask = get_attention_mask(
+        attention_mask_4d, flex_block_mask = get_attention_mask(
             attn_backend=self.config.attn_backend,
             batch_size=input_ids.shape[0],
             seq_len=input_ids.shape[1],
@@ -818,7 +822,7 @@
         )
         encoder_outputs = self.encoder(
             token_embedding_output,
-            attention_mask=attention_mask,
+            attention_mask=attention_mask_4d,
             flex_block_mask=flex_block_mask,
             output_hidden_states=False,
             output_attentions=False,
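
Note (not part of the commit): if the non-flex backend resolves to PyTorch SDPA inside the encoder (an assumption, since the encoder internals are not shown in this diff), a (batch, 1, q_len, kv_len) boolean mask like attention_mask_4d can be passed straight through, broadcasting over heads.

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 5, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

mask_2d = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]).bool()
attention_mask_4d = mask_2d[:, None, :, None] & mask_2d[:, None, None, :]

# Boolean attn_mask: True = may attend. The singleton head dim broadcasts.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask_4d)
print(out.shape)  # torch.Size([2, 4, 5, 16])
# Rows belonging to padded query positions are fully masked; their outputs
# come back as NaN and are ignored downstream via the same padding mask.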
@@ -874,7 +878,7 @@
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
         )
-        attention_mask, flex_block_mask = get_attention_mask(
+        attention_mask_4d, flex_block_mask = get_attention_mask(
             attn_backend=self.config.attn_backend,
             batch_size=input_ids.shape[0],
             seq_len=input_ids.shape[1],
@@ -883,7 +887,7 @@
         )
         encoder_outputs = self.encoder(
             token_embedding_output,
-            attention_mask=attention_mask,
+            attention_mask=attention_mask_4d,
             flex_block_mask=flex_block_mask,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
 