lhallee commited on
Commit
dbac9d3
·
verified ·
1 Parent(s): e24a972

Upload modeling_dplm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm.py +38 -27
modeling_dplm.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch._inductor.config as inductor_config
3
  import torch._dynamo as dynamo
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
429
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
430
  """
431
  from enum import Enum
432
- from typing import Optional
 
433
 
434
  import torch
435
  import torch.nn as nn
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
447
 
448
 
449
  def _get_flex_attention_fn():
450
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
 
 
 
 
 
451
  global _compiled_flex_attention
452
  if flex_attention is None:
453
  return None
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
455
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
456
  return flex_attention
457
  if _compiled_flex_attention is None:
458
- _compiled_flex_attention = torch.compile(flex_attention)
 
 
 
459
  return _compiled_flex_attention
460
 
461
 
462
  ### Kernels Flash Attention Detection
463
- def _infer_kernels_flash_variant(kernel) -> str | None:
464
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
465
  return "flash_attn2"
466
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
576
  ).reshape(-1, *other_shape)
577
 
578
  @staticmethod
579
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
580
  (indices,) = ctx.saved_tensors
581
  assert grad_output.ndim >= 2
582
  other_shape = grad_output.shape[1:]
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
599
  return output
600
 
601
  @staticmethod
602
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
603
  (indices,) = ctx.saved_tensors
604
  return grad_output[indices], None, None
605
 
@@ -618,7 +629,7 @@ def _unpad_input(
618
  key_layer: torch.Tensor,
619
  value_layer: torch.Tensor,
620
  attention_mask_2d: torch.Tensor,
621
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
622
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
623
  seqlens = attention_mask_2d.sum(dim=1).int()
624
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
634
  query_states: torch.Tensor,
635
  key_states: torch.Tensor,
636
  value_states: torch.Tensor,
637
- attention_mask_2d: torch.Tensor | None = None,
638
  causal: bool = False,
639
  ) -> torch.Tensor:
640
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
@@ -707,7 +718,7 @@ def get_attention_mask(
707
  seq_len: int,
708
  device: torch.device,
709
  attention_mask: Optional[torch.Tensor] = None,
710
- ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
711
  """Build padding masks once for all encoder layers.
712
 
713
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
@@ -782,7 +793,7 @@ class DPLMMaskedLMOutput(ModelOutput):
782
  last_hidden_state: Optional[torch.Tensor] = None
783
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
784
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
785
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
786
 
787
 
788
  @dataclass
@@ -790,7 +801,7 @@ class DPLMEncoderOutput(ModelOutput):
790
  last_hidden_state: Optional[torch.Tensor] = None
791
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
792
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
793
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
794
 
795
 
796
  class DPLMConfig(EsmConfig):
@@ -859,7 +870,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
859
  output_attentions: Optional[bool] = False,
860
  output_s_max: Optional[bool] = False,
861
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
862
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list[torch.Tensor]]]:
863
  if past_key_values is not None:
864
  past_key_value = past_key_values
865
 
@@ -930,12 +941,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
930
  query_BHLD: torch.Tensor,
931
  key_BHLD: torch.Tensor,
932
  value_BHLD: torch.Tensor,
933
- attention_mask_2d: torch.Tensor | None = None,
934
- attention_mask_4d: torch.Tensor | None = None,
935
- flex_block_mask: "BlockMask | None" = None,
936
  output_attentions: bool = False,
937
  output_s_max: bool = False,
938
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
939
  if output_attentions:
940
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
941
 
@@ -952,7 +963,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
952
  return attn_output, attn_weights, s_max
953
 
954
  @torch.no_grad()
955
- def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]:
956
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
957
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
958
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
@@ -963,9 +974,9 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
963
  query_BHLD: torch.Tensor,
964
  key_BHLD: torch.Tensor,
965
  value_BHLD: torch.Tensor,
966
- attention_mask_4d: torch.Tensor | None = None,
967
  output_s_max: bool = False,
968
- ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
969
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
970
  if attention_mask_4d is not None:
971
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
@@ -980,8 +991,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
980
  query_BHLD: torch.Tensor,
981
  key_BHLD: torch.Tensor,
982
  value_BHLD: torch.Tensor,
983
- attention_mask_2d: torch.Tensor | None = None,
984
- ) -> tuple[torch.Tensor, None]:
985
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
986
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
987
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
@@ -996,8 +1007,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
996
  query_BHLD: torch.Tensor,
997
  key_BHLD: torch.Tensor,
998
  value_BHLD: torch.Tensor,
999
- flex_block_mask: "BlockMask | None" = None,
1000
- ) -> tuple[torch.Tensor, None]:
1001
  assert flex_attention is not None, "Flex attention is not available in this environment."
1002
  fn = _get_flex_attention_fn()
1003
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
@@ -1008,8 +1019,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1008
  query_BHLD: torch.Tensor,
1009
  key_BHLD: torch.Tensor,
1010
  value_BHLD: torch.Tensor,
1011
- attention_mask_4d: torch.Tensor | None = None,
1012
- ) -> tuple[torch.Tensor, None]:
1013
  context_BHLD = F.scaled_dot_product_attention(
1014
  query_BHLD, key_BHLD, value_BHLD,
1015
  attn_mask=attention_mask_4d,
@@ -1038,7 +1049,7 @@ class ModifiedEsmAttention(EsmAttention):
1038
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1039
  output_attentions: bool = False,
1040
  output_s_max: bool = False,
1041
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list[torch.Tensor]]]:
1042
  hidden_states_ln = self.LayerNorm(hidden_states)
1043
  attn_output, attn_weights, s_max = self.self(
1044
  hidden_states_ln,
@@ -1084,7 +1095,7 @@ class ModifiedEsmLayer(EsmLayer):
1084
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1085
  output_attentions: bool = False,
1086
  output_s_max: bool = False,
1087
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list[torch.Tensor]]]:
1088
  attention_output, attn_weights, s_max = self.attention(
1089
  hidden_states,
1090
  attention_mask_2d=attention_mask_2d,
 
1
+ from __future__ import annotations
2
+
3
  import torch
4
  import torch._inductor.config as inductor_config
5
  import torch._dynamo as dynamo
 
431
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
  """
433
  from enum import Enum
434
+ from functools import partial
435
+ from typing import Dict, List, Optional, Tuple
436
 
437
  import torch
438
  import torch.nn as nn
 
450
 
451
 
452
  def _get_flex_attention_fn():
453
+ """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
454
+
455
+ Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
456
+ on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
457
+ on older hardware.
458
+ """
459
  global _compiled_flex_attention
460
  if flex_attention is None:
461
  return None
 
463
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
464
  return flex_attention
465
  if _compiled_flex_attention is None:
466
+ _compiled_flex_attention = torch.compile(
467
+ partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
468
+ dynamic=False,
469
+ )
470
  return _compiled_flex_attention
471
 
472
 
473
  ### Kernels Flash Attention Detection
474
+ def _infer_kernels_flash_variant(kernel) -> Optional[str]:
475
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
476
  return "flash_attn2"
477
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
 
587
  ).reshape(-1, *other_shape)
588
 
589
  @staticmethod
590
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
591
  (indices,) = ctx.saved_tensors
592
  assert grad_output.ndim >= 2
593
  other_shape = grad_output.shape[1:]
 
610
  return output
611
 
612
  @staticmethod
613
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
614
  (indices,) = ctx.saved_tensors
615
  return grad_output[indices], None, None
616
 
 
629
  key_layer: torch.Tensor,
630
  value_layer: torch.Tensor,
631
  attention_mask_2d: torch.Tensor,
632
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
633
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
634
  seqlens = attention_mask_2d.sum(dim=1).int()
635
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
 
645
  query_states: torch.Tensor,
646
  key_states: torch.Tensor,
647
  value_states: torch.Tensor,
648
+ attention_mask_2d: Optional[torch.Tensor] = None,
649
  causal: bool = False,
650
  ) -> torch.Tensor:
651
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
 
718
  seq_len: int,
719
  device: torch.device,
720
  attention_mask: Optional[torch.Tensor] = None,
721
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
722
  """Build padding masks once for all encoder layers.
723
 
724
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
 
793
  last_hidden_state: Optional[torch.Tensor] = None
794
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
795
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
796
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
797
 
798
 
799
  @dataclass
 
801
  last_hidden_state: Optional[torch.Tensor] = None
802
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
803
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
804
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
805
 
806
 
807
  class DPLMConfig(EsmConfig):
 
870
  output_attentions: Optional[bool] = False,
871
  output_s_max: Optional[bool] = False,
872
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
873
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
874
  if past_key_values is not None:
875
  past_key_value = past_key_values
876
 
 
941
  query_BHLD: torch.Tensor,
942
  key_BHLD: torch.Tensor,
943
  value_BHLD: torch.Tensor,
944
+ attention_mask_2d: Optional[torch.Tensor] = None,
945
+ attention_mask_4d: Optional[torch.Tensor] = None,
946
+ flex_block_mask: Optional[BlockMask] = None,
947
  output_attentions: bool = False,
948
  output_s_max: bool = False,
949
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
950
  if output_attentions:
951
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
952
 
 
963
  return attn_output, attn_weights, s_max
964
 
965
  @torch.no_grad()
966
+ def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
967
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
968
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
969
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
 
974
  query_BHLD: torch.Tensor,
975
  key_BHLD: torch.Tensor,
976
  value_BHLD: torch.Tensor,
977
+ attention_mask_4d: Optional[torch.Tensor] = None,
978
  output_s_max: bool = False,
979
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
980
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
981
  if attention_mask_4d is not None:
982
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
 
991
  query_BHLD: torch.Tensor,
992
  key_BHLD: torch.Tensor,
993
  value_BHLD: torch.Tensor,
994
+ attention_mask_2d: Optional[torch.Tensor] = None,
995
+ ) -> Tuple[torch.Tensor, None]:
996
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
997
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
998
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
1007
  query_BHLD: torch.Tensor,
1008
  key_BHLD: torch.Tensor,
1009
  value_BHLD: torch.Tensor,
1010
+ flex_block_mask: Optional[BlockMask] = None,
1011
+ ) -> Tuple[torch.Tensor, None]:
1012
  assert flex_attention is not None, "Flex attention is not available in this environment."
1013
  fn = _get_flex_attention_fn()
1014
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
 
1019
  query_BHLD: torch.Tensor,
1020
  key_BHLD: torch.Tensor,
1021
  value_BHLD: torch.Tensor,
1022
+ attention_mask_4d: Optional[torch.Tensor] = None,
1023
+ ) -> Tuple[torch.Tensor, None]:
1024
  context_BHLD = F.scaled_dot_product_attention(
1025
  query_BHLD, key_BHLD, value_BHLD,
1026
  attn_mask=attention_mask_4d,
 
1049
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1050
  output_attentions: bool = False,
1051
  output_s_max: bool = False,
1052
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1053
  hidden_states_ln = self.LayerNorm(hidden_states)
1054
  attn_output, attn_weights, s_max = self.self(
1055
  hidden_states_ln,
 
1095
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1096
  output_attentions: bool = False,
1097
  output_s_max: bool = False,
1098
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1099
  attention_output, attn_weights, s_max = self.attention(
1100
  hidden_states,
1101
  attention_mask_2d=attention_mask_2d,