lhallee commited on
Commit
2f2bca8
·
verified ·
1 Parent(s): b3ee1c0

Upload modeling_dplm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm.py +23 -16
modeling_dplm.py CHANGED
@@ -420,9 +420,9 @@ def get_attention_mask(
420
  attention_mask: Optional[torch.Tensor] = None,
421
  ) -> Tuple[Optional[torch.Tensor], Optional[object]]:
422
  if attention_mask is None:
423
- token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
424
  else:
425
- token_attention_mask = attention_mask.bool()
426
 
427
  if attn_backend == "flex":
428
  assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
@@ -430,8 +430,10 @@ def get_attention_mask(
430
  if attention_mask is None:
431
  flex_block_mask = None
432
  else:
 
 
433
  def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
434
- return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
435
 
436
  flex_block_mask = create_block_mask(
437
  mask_mod,
@@ -441,12 +443,12 @@ def get_attention_mask(
441
  seq_len,
442
  device=device,
443
  )
444
- extended_attention_mask = None
445
  else:
446
  flex_block_mask = None
447
- extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
448
 
449
- return extended_attention_mask, flex_block_mask
450
 
451
 
452
  @dataclass
@@ -478,6 +480,11 @@ class DPLMPreTrainedModel(EsmPreTrainedModel):
478
  tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
479
  all_tied_weights_keys = {}
480
 
 
 
 
 
 
481
  @property
482
  def attn_backend(self) -> str:
483
  return self.config.attn_backend
@@ -899,12 +906,12 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
899
  past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
900
 
901
  if attention_mask is None:
902
- token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
903
  elif attention_mask.dim() == 2:
904
- token_attention_mask = attention_mask.bool()
905
  elif attention_mask.dim() == 4:
906
  assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
907
- token_attention_mask = input_ids.ne(self.config.pad_token_id)
908
  else:
909
  raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
910
 
@@ -919,19 +926,19 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
919
 
920
  head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
921
 
922
- embedding_attention_mask = token_attention_mask
923
  if embedding_attention_mask is None and input_ids is not None:
924
  embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
925
 
926
  if self.config.attn_backend == "flex" and output_attentions:
927
  raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
928
 
929
- extended_attention_mask, flex_block_mask = get_attention_mask(
930
  attn_backend=self.config.attn_backend,
931
  batch_size=batch_size,
932
  seq_len=seq_length,
933
  device=device,
934
- attention_mask=token_attention_mask,
935
  )
936
 
937
  embedding_output = self.embeddings(
@@ -942,7 +949,7 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
942
  )
943
  encoder_outputs = self.encoder(
944
  embedding_output,
945
- attention_mask=extended_attention_mask,
946
  head_mask=head_mask,
947
  encoder_hidden_states=encoder_hidden_states,
948
  encoder_attention_mask=encoder_extended_attention_mask,
@@ -1041,7 +1048,7 @@ class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
1041
  def __init__(self, config, dropout: float = 0.1):
1042
  config.hidden_dropout_prob = dropout
1043
  DPLMPreTrainedModel.__init__(self, config)
1044
- self.esm = DPLMModel(config, add_pooling_layer=False)
1045
  self.lm_head = EsmLMHead(config)
1046
  self.loss_fct = nn.CrossEntropyLoss()
1047
  self.post_init()
@@ -1136,7 +1143,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
1136
  def __init__(self, config):
1137
  DPLMPreTrainedModel.__init__(self, config)
1138
  self.num_labels = config.num_labels
1139
- self.esm = DPLMModel(config, add_pooling_layer=False)
1140
  self.classifier = EsmClassificationHead(config)
1141
  self.mse = nn.MSELoss()
1142
  self.ce = nn.CrossEntropyLoss()
@@ -1206,7 +1213,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
1206
  def __init__(self, config):
1207
  DPLMPreTrainedModel.__init__(self, config)
1208
  self.num_labels = config.num_labels
1209
- self.esm = DPLMModel(config, add_pooling_layer=False)
1210
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
1211
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1212
  self.loss_fct = nn.CrossEntropyLoss()
 
420
  attention_mask: Optional[torch.Tensor] = None,
421
  ) -> Tuple[Optional[torch.Tensor], Optional[object]]:
422
  if attention_mask is None:
423
+ attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
424
  else:
425
+ attention_mask_2d = attention_mask.bool()
426
 
427
  if attn_backend == "flex":
428
  assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
 
430
  if attention_mask is None:
431
  flex_block_mask = None
432
  else:
433
+ valid_lens = attention_mask_2d.sum(dim=-1)
434
+
435
  def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
436
+ return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
437
 
438
  flex_block_mask = create_block_mask(
439
  mask_mod,
 
443
  seq_len,
444
  device=device,
445
  )
446
+ attention_mask_4d = None
447
  else:
448
  flex_block_mask = None
449
+ attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
450
 
451
+ return attention_mask_4d, flex_block_mask
452
 
453
 
454
  @dataclass
 
480
  tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
481
  all_tied_weights_keys = {}
482
 
483
+ @classmethod
484
+ def is_remote_code(cls) -> bool:
485
+ # Prevent post-load reinitialization of tensors already loaded from checkpoints.
486
+ return True
487
+
488
  @property
489
  def attn_backend(self) -> str:
490
  return self.config.attn_backend
 
906
  past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
907
 
908
  if attention_mask is None:
909
+ attention_mask_2d = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
910
  elif attention_mask.dim() == 2:
911
+ attention_mask_2d = attention_mask.bool()
912
  elif attention_mask.dim() == 4:
913
  assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
914
+ attention_mask_2d = input_ids.ne(self.config.pad_token_id)
915
  else:
916
  raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
917
 
 
926
 
927
  head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
928
 
929
+ embedding_attention_mask = attention_mask_2d
930
  if embedding_attention_mask is None and input_ids is not None:
931
  embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
932
 
933
  if self.config.attn_backend == "flex" and output_attentions:
934
  raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
935
 
936
+ attention_mask_4d, flex_block_mask = get_attention_mask(
937
  attn_backend=self.config.attn_backend,
938
  batch_size=batch_size,
939
  seq_len=seq_length,
940
  device=device,
941
+ attention_mask=attention_mask_2d,
942
  )
943
 
944
  embedding_output = self.embeddings(
 
949
  )
950
  encoder_outputs = self.encoder(
951
  embedding_output,
952
+ attention_mask=attention_mask_4d,
953
  head_mask=head_mask,
954
  encoder_hidden_states=encoder_hidden_states,
955
  encoder_attention_mask=encoder_extended_attention_mask,
 
1048
  def __init__(self, config, dropout: float = 0.1):
1049
  config.hidden_dropout_prob = dropout
1050
  DPLMPreTrainedModel.__init__(self, config)
1051
+ self.esm = FAST_DPLM_ENCODER(config)
1052
  self.lm_head = EsmLMHead(config)
1053
  self.loss_fct = nn.CrossEntropyLoss()
1054
  self.post_init()
 
1143
  def __init__(self, config):
1144
  DPLMPreTrainedModel.__init__(self, config)
1145
  self.num_labels = config.num_labels
1146
+ self.esm = FAST_DPLM_ENCODER(config)
1147
  self.classifier = EsmClassificationHead(config)
1148
  self.mse = nn.MSELoss()
1149
  self.ce = nn.CrossEntropyLoss()
 
1213
  def __init__(self, config):
1214
  DPLMPreTrainedModel.__init__(self, config)
1215
  self.num_labels = config.num_labels
1216
+ self.esm = FAST_DPLM_ENCODER(config)
1217
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
1218
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1219
  self.loss_fct = nn.CrossEntropyLoss()