lhallee committed on
Commit d2fcb7b · verified · 1 Parent(s): 4549a6a

Upload modeling_dplm.py with huggingface_hub

Files changed (1)
  1. modeling_dplm.py +87 -75
modeling_dplm.py CHANGED
@@ -412,22 +412,38 @@ class BaseSequenceTokenizer:
         raise NotImplementedError
 
 
-def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
-    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
-    token_valid = attention_mask_2d.bool()
-    batch_size, seq_len = token_valid.shape
-
-    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
-
-    return create_block_mask(
-        mask_mod,
-        batch_size,
-        1,
-        seq_len,
-        seq_len,
-        device=attention_mask_2d.device,
-    )
 
 
 @dataclass
@@ -459,11 +475,20 @@ class DPLMPreTrainedModel(EsmPreTrainedModel):
     tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
     all_tied_weights_keys = {}
 
 
 class ModifiedEsmSelfAttention(EsmSelfAttention):
     def __init__(self, config, position_embedding_type=None):
         super().__init__(config, position_embedding_type)
-        self.attn_backend = config.attn_backend
 
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -473,7 +498,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
@@ -522,24 +547,21 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
         value_layer = value_layer.contiguous()
 
         if output_attentions:
             attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-            if attention_mask is not None:
-                attention_scores = attention_scores + attention_mask
             attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
             context_layer = torch.matmul(attention_probs, value_layer)
         else:
             attention_probs = None
-            if self.attn_backend == "flex":
                 assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                 assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                     f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                 )
                 assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
                 assert past_key_value is None, "Flex attention backend currently does not support KV caching."
-                if attention_mask is not None:
-                    assert flex_block_mask is not None, (
-                        "Flex attention backend requires a block mask when attention_mask is provided."
-                    )
                 context_layer = flex_attention(
                     query_layer,
                     key_layer,
@@ -579,14 +601,14 @@ class ModifiedEsmAttention(EsmAttention):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-        flex_block_mask=None,
     ):
         hidden_states_ln = self.LayerNorm(hidden_states)
         self_outputs = self.self(
@@ -622,14 +644,14 @@ class ModifiedEsmLayer(EsmLayer):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-        flex_block_mask=None,
     ):
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
@@ -688,17 +710,17 @@ class ModifiedEsmEncoder(EsmEncoder):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-        flex_block_mask=None,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -873,22 +895,12 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
-
-        token_attention_mask = None
-        if attention_mask.dim() == 2:
             token_attention_mask = attention_mask.bool()
-            if self.config.attn_backend == "flex" and output_attentions is False:
-                extended_attention_mask = None
-            else:
-                extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
         elif attention_mask.dim() == 4:
-            if self.config.attn_backend == "flex" and output_attentions is False:
-                extended_attention_mask = None
-            else:
-                extended_attention_mask = attention_mask
-            if input_ids is not None:
-                token_attention_mask = input_ids.ne(self.config.pad_token_id)
         else:
             raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
 
@@ -907,16 +919,16 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         if embedding_attention_mask is None and input_ids is not None:
             embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
 
-        flex_block_mask = None
-        if (
-            self.config.attn_backend == "flex"
-            and token_attention_mask is not None
-            and output_attentions is False
-        ):
-            assert create_block_mask is not None, (
-                "Flex attention backend requested but torch.create_block_mask is unavailable."
-            )
-            flex_block_mask = _create_pad_block_mask(token_attention_mask)
 
         embedding_output = self.embeddings(
             input_ids=input_ids,
 
@@ -412,22 +412,38 @@ class BaseSequenceTokenizer:
         raise NotImplementedError
 
 
+def get_attention_mask(
+    attn_backend: str,
+    batch_size: int,
+    seq_len: int,
+    device: torch.device,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Tuple[Optional[torch.Tensor], Optional[object]]:
+    if attention_mask is None:
+        token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
+    else:
+        token_attention_mask = attention_mask.bool()
+
+    if attn_backend == "flex":
+        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
+
+        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
+
+        flex_block_mask = create_block_mask(
+            mask_mod,
+            batch_size,
+            1,
+            seq_len,
+            seq_len,
+            device=device,
+        )
+        extended_attention_mask = None
+    else:
+        flex_block_mask = None
+        extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
+
+    return extended_attention_mask, flex_block_mask
 
 
 @dataclass
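For orientation, here is a small standalone sketch of what the helper above returns on the "sdpa" path for a padded toy batch; the sizes and variable names are illustrative assumptions. The "flex" path instead returns a BlockMask built from the same pairwise rule (see the sketch near the end of this page).

import torch

# Toy padded batch: two sequences of length 5; the second has two padding positions.
token_attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
]).bool()

# "sdpa" path: a broadcastable [batch, 1, query, key] boolean mask that is True
# only where both the query token and the key token are real (non-padding).
extended_attention_mask = (
    token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
)
print(extended_attention_mask.shape)   # torch.Size([2, 1, 5, 5])
print(extended_attention_mask[1, 0])   # last two rows/columns are False for the padded sequence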
 
@@ -459,11 +475,20 @@ class DPLMPreTrainedModel(EsmPreTrainedModel):
     tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
     all_tied_weights_keys = {}
 
+    @property
+    def attn_backend(self) -> str:
+        return self.config.attn_backend
+
+    @attn_backend.setter
+    def attn_backend(self, backend: str) -> None:
+        assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
+        self.config.attn_backend = backend
+
 
 class ModifiedEsmSelfAttention(EsmSelfAttention):
     def __init__(self, config, position_embedding_type=None):
         super().__init__(config, position_embedding_type)
+        self.config = config
 
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
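A minimal, self-contained sketch of the property semantics added above: the setter validates the value and writes it through to the config, so every submodule that reads self.config.attn_backend (e.g. ModifiedEsmSelfAttention) sees the switch. The stand-in class below is an assumption for illustration; the real property lives on DPLMPreTrainedModel.

from types import SimpleNamespace

class BackendHolder:
    """Illustrative stand-in mirroring the attn_backend property/setter pair."""

    def __init__(self, config):
        self.config = config

    @property
    def attn_backend(self) -> str:
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
        self.config.attn_backend = backend

config = SimpleNamespace(attn_backend="sdpa")
holder = BackendHolder(config)
holder.attn_backend = "flex"       # accepted; writes through to the shared config
print(config.attn_backend)         # "flex"
# holder.attn_backend = "eager"    # would raise: only "sdpa" and "flex" are allowed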
 
@@ -473,7 +498,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
 
@@ -522,24 +547,21 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
         value_layer = value_layer.contiguous()
 
         if output_attentions:
+            assert attention_mask is not None, "output_attentions=True requires a concrete attention mask."
             attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+            attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
             attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
             context_layer = torch.matmul(attention_probs, value_layer)
         else:
             attention_probs = None
+            if self.config.attn_backend == "flex":
                 assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                 assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                     f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                 )
                 assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
                 assert past_key_value is None, "Flex attention backend currently does not support KV caching."
+                assert flex_block_mask is not None, "Flex attention backend requires a block mask."
                 context_layer = flex_attention(
                     query_layer,
                     key_layer,
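The output_attentions branch above swaps the old additive float mask for a boolean mask plus masked_fill. A quick self-contained check (toy shapes assumed) that the two give the same softmax probabilities, with padded key positions receiving exactly zero attention:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1, 1, 4, 4)                 # [batch, heads, query, key]
keep = torch.tensor([True, True, True, False])   # last key position is padding
bool_mask = keep[None, None, None, :]            # broadcastable boolean mask

# New path: fill masked positions with -inf before the softmax.
probs_new = F.softmax(scores.masked_fill(bool_mask.logical_not(), float("-inf")),
                      dim=-1, dtype=torch.float32)

# Old path: add a float mask that is 0 where kept and -inf where masked.
additive = torch.zeros_like(scores).masked_fill(bool_mask.logical_not(), float("-inf"))
probs_old = F.softmax(scores + additive, dim=-1, dtype=torch.float32)

print(torch.allclose(probs_new, probs_old))  # True
print(probs_new[0, 0, 0])                    # last entry is exactly 0.0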
 
@@ -579,14 +601,14 @@ class ModifiedEsmAttention(EsmAttention):
 
     def forward(
         self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: bool = False,
+        flex_block_mask: Optional[object] = None,
    ):
         hidden_states_ln = self.LayerNorm(hidden_states)
         self_outputs = self.self(
 
@@ -622,14 +644,14 @@ class ModifiedEsmLayer(EsmLayer):
 
     def forward(
         self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: bool = False,
+        flex_block_mask: Optional[object] = None,
    ):
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
 
@@ -688,17 +710,17 @@ class ModifiedEsmEncoder(EsmEncoder):
 
     def forward(
         self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        flex_block_mask: Optional[object] = None,
    ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
 
@@ -873,22 +895,12 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if attention_mask is None:
+            token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
+        elif attention_mask.dim() == 2:
             token_attention_mask = attention_mask.bool()
         elif attention_mask.dim() == 4:
+            assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
+            token_attention_mask = input_ids.ne(self.config.pad_token_id)
         else:
             raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
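As the branch above shows, the model now only needs a token-level boolean mask: an all-ones mask when none is given, a 2D mask used directly, or a mask recovered from input_ids and the pad token when a prebuilt 4D mask is passed. A short sketch of the caller-side shapes, with a pad id assumed for illustration:

import torch

pad_token_id = 1  # assumed ESM-style pad id for this sketch
input_ids = torch.tensor([
    [0, 5, 6, 7, 2],
    [0, 5, 6, 2, 1],   # last position is padding
])

# 2D case: [batch, seq_len] mask marking real tokens.
attention_mask_2d = input_ids.ne(pad_token_id)
print(attention_mask_2d)

# None case: the model falls back to an all-ones mask of the same shape.
# 4D case: a prebuilt [batch, 1, seq_len, seq_len] mask is accepted, but input_ids
# must also be provided so the token-level mask can be recovered from pad_token_id.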
 
 
@@ -907,16 +919,16 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         if embedding_attention_mask is None and input_ids is not None:
             embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
 
+        if self.config.attn_backend == "flex" and output_attentions:
+            raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
+
+        extended_attention_mask, flex_block_mask = get_attention_mask(
+            attn_backend=self.config.attn_backend,
+            batch_size=batch_size,
+            seq_len=seq_length,
+            device=device,
+            attention_mask=token_attention_mask,
+        )
 
         embedding_output = self.embeddings(
             input_ids=input_ids,
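Finally, a hedged sketch of the flex path that get_attention_mask and ModifiedEsmSelfAttention rely on: a BlockMask built from the padding rule, then flex_attention over half-precision tensors. The shapes, the CPU/CUDA fallback, and the broad exception guard are assumptions so the sketch degrades gracefully on PyTorch builds without flex attention.

import torch

try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch, heads, seq, head_dim = 2, 4, 128, 16
    q = torch.randn(batch, heads, seq, head_dim, device=device, dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    token_mask = torch.ones(batch, seq, dtype=torch.bool, device=device)
    token_mask[1, 100:] = False  # pad out the tail of the second sequence

    def mask_mod(b, h, q_idx, kv_idx):
        # Same pairwise rule as the model: attend only between real tokens.
        return token_mask[b, q_idx] & token_mask[b, kv_idx]

    block_mask = create_block_mask(mask_mod, batch, 1, seq, seq, device=device)
    out = flex_attention(q, k, v, block_mask=block_mask)
    print(out.shape)  # torch.Size([2, 4, 128, 16])
except Exception as exc:  # older PyTorch, or no flex support on this device
    print(f"flex attention unavailable here: {exc}")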