Transformers
Safetensors
dplm2
custom_code
lhallee committed on
Commit
88c7873
·
verified ·
1 Parent(s): b04f443

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +47 -36
modeling_dplm2.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch._inductor.config as inductor_config
3
  import torch._dynamo as dynamo
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
429
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
430
  """
431
  from enum import Enum
432
- from typing import Optional
 
433
 
434
  import torch
435
  import torch.nn as nn
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
447
 
448
 
449
  def _get_flex_attention_fn():
450
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
 
 
 
 
 
451
  global _compiled_flex_attention
452
  if flex_attention is None:
453
  return None
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
455
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
456
  return flex_attention
457
  if _compiled_flex_attention is None:
458
- _compiled_flex_attention = torch.compile(flex_attention)
 
 
 
459
  return _compiled_flex_attention
460
 
461
 
462
  ### Kernels Flash Attention Detection
463
- def _infer_kernels_flash_variant(kernel) -> str | None:
464
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
465
  return "flash_attn2"
466
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
576
  ).reshape(-1, *other_shape)
577
 
578
  @staticmethod
579
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
580
  (indices,) = ctx.saved_tensors
581
  assert grad_output.ndim >= 2
582
  other_shape = grad_output.shape[1:]
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
599
  return output
600
 
601
  @staticmethod
602
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
603
  (indices,) = ctx.saved_tensors
604
  return grad_output[indices], None, None
605
 
@@ -618,7 +629,7 @@ def _unpad_input(
618
  key_layer: torch.Tensor,
619
  value_layer: torch.Tensor,
620
  attention_mask_2d: torch.Tensor,
621
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
622
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
623
  seqlens = attention_mask_2d.sum(dim=1).int()
624
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
634
  query_states: torch.Tensor,
635
  key_states: torch.Tensor,
636
  value_states: torch.Tensor,
637
- attention_mask_2d: torch.Tensor | None = None,
638
  causal: bool = False,
639
  ) -> torch.Tensor:
640
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
@@ -707,7 +718,7 @@ def get_attention_mask(
707
  seq_len: int,
708
  device: torch.device,
709
  attention_mask: Optional[torch.Tensor] = None,
710
- ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
711
  """Build padding masks once for all encoder layers.
712
 
713
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
@@ -838,7 +849,7 @@ class DPLM2MaskedLMOutput(ModelOutput):
838
  last_hidden_state: Optional[torch.Tensor] = None
839
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
840
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
841
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
842
 
843
 
844
  @dataclass
@@ -846,7 +857,7 @@ class DPLM2EncoderOutput(ModelOutput):
846
  last_hidden_state: Optional[torch.Tensor] = None
847
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
848
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
849
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
850
 
851
 
852
  class DPLM2Config(EsmConfig):
@@ -986,13 +997,13 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
986
  def forward(
987
  self,
988
  hidden_states: torch.Tensor,
989
- attention_mask_2d: torch.Tensor | None = None,
990
- attention_mask_4d: torch.Tensor | None = None,
991
- flex_block_mask: "BlockMask | None" = None,
992
  output_attentions: bool = False,
993
  output_s_max: bool = False,
994
  type_ids: Optional[torch.Tensor] = None,
995
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
996
  batch_size, seq_length = hidden_states.shape[:-1]
997
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
998
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
@@ -1019,12 +1030,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1019
  query_BHLD: torch.Tensor,
1020
  key_BHLD: torch.Tensor,
1021
  value_BHLD: torch.Tensor,
1022
- attention_mask_2d: torch.Tensor | None = None,
1023
- attention_mask_4d: torch.Tensor | None = None,
1024
- flex_block_mask: "BlockMask | None" = None,
1025
  output_attentions: bool = False,
1026
  output_s_max: bool = False,
1027
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1028
  if output_attentions:
1029
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
1030
 
@@ -1041,7 +1052,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1041
  return attn_output, attn_weights, s_max
1042
 
1043
  @torch.no_grad()
1044
- def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]:
1045
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
1046
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
1047
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
@@ -1052,9 +1063,9 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1052
  query_BHLD: torch.Tensor,
1053
  key_BHLD: torch.Tensor,
1054
  value_BHLD: torch.Tensor,
1055
- attention_mask_4d: torch.Tensor | None = None,
1056
  output_s_max: bool = False,
1057
- ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
1058
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
1059
  if attention_mask_4d is not None:
1060
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
@@ -1071,8 +1082,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1071
  query_BHLD: torch.Tensor,
1072
  key_BHLD: torch.Tensor,
1073
  value_BHLD: torch.Tensor,
1074
- attention_mask_2d: torch.Tensor | None = None,
1075
- ) -> tuple[torch.Tensor, None]:
1076
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1077
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1078
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
@@ -1087,8 +1098,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1087
  query_BHLD: torch.Tensor,
1088
  key_BHLD: torch.Tensor,
1089
  value_BHLD: torch.Tensor,
1090
- flex_block_mask: "BlockMask | None" = None,
1091
- ) -> tuple[torch.Tensor, None]:
1092
  assert flex_attention is not None, "Flex attention is not available in this environment."
1093
  fn = _get_flex_attention_fn()
1094
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
@@ -1099,8 +1110,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1099
  query_BHLD: torch.Tensor,
1100
  key_BHLD: torch.Tensor,
1101
  value_BHLD: torch.Tensor,
1102
- attention_mask_4d: torch.Tensor | None = None,
1103
- ) -> tuple[torch.Tensor, None]:
1104
  context_BHLD = F.scaled_dot_product_attention(
1105
  query_BHLD, key_BHLD, value_BHLD,
1106
  attn_mask=attention_mask_4d,
@@ -1120,13 +1131,13 @@ class ModifiedEsmAttention(EsmAttention):
1120
  def forward(
1121
  self,
1122
  hidden_states: torch.Tensor,
1123
- attention_mask_2d: torch.Tensor | None = None,
1124
- attention_mask_4d: torch.Tensor | None = None,
1125
- flex_block_mask: "BlockMask | None" = None,
1126
  output_attentions: bool = False,
1127
  output_s_max: bool = False,
1128
  type_ids: Optional[torch.Tensor] = None,
1129
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1130
  hidden_states_ln = self.LayerNorm(hidden_states)
1131
  attn_output, attn_weights, s_max = self.self(
1132
  hidden_states_ln,
@@ -1154,13 +1165,13 @@ class ModifiedEsmLayer(EsmLayer):
1154
  def forward(
1155
  self,
1156
  hidden_states: torch.Tensor,
1157
- attention_mask_2d: torch.Tensor | None = None,
1158
- attention_mask_4d: torch.Tensor | None = None,
1159
- flex_block_mask: "BlockMask | None" = None,
1160
  output_attentions: bool = False,
1161
  output_s_max: bool = False,
1162
  type_ids: Optional[torch.Tensor] = None,
1163
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1164
  attention_output, attn_weights, s_max = self.attention(
1165
  hidden_states,
1166
  attention_mask_2d=attention_mask_2d,
 
1
+ from __future__ import annotations
2
+
3
  import torch
4
  import torch._inductor.config as inductor_config
5
  import torch._dynamo as dynamo
 
431
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
  """
433
  from enum import Enum
434
+ from functools import partial
435
+ from typing import Dict, List, Optional, Tuple
436
 
437
  import torch
438
  import torch.nn as nn
 
450
 
451
 
452
  def _get_flex_attention_fn():
453
+ """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
454
+
455
+ Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
456
+ on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
457
+ on older hardware.
458
+ """
459
  global _compiled_flex_attention
460
  if flex_attention is None:
461
  return None
 
463
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
464
  return flex_attention
465
  if _compiled_flex_attention is None:
466
+ _compiled_flex_attention = torch.compile(
467
+ partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
468
+ dynamic=False,
469
+ )
470
  return _compiled_flex_attention
471
 
472
 
473
  ### Kernels Flash Attention Detection
474
+ def _infer_kernels_flash_variant(kernel) -> Optional[str]:
475
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
476
  return "flash_attn2"
477
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
 
587
  ).reshape(-1, *other_shape)
588
 
589
  @staticmethod
590
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
591
  (indices,) = ctx.saved_tensors
592
  assert grad_output.ndim >= 2
593
  other_shape = grad_output.shape[1:]
 
610
  return output
611
 
612
  @staticmethod
613
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
614
  (indices,) = ctx.saved_tensors
615
  return grad_output[indices], None, None
616
 
 
629
  key_layer: torch.Tensor,
630
  value_layer: torch.Tensor,
631
  attention_mask_2d: torch.Tensor,
632
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
633
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
634
  seqlens = attention_mask_2d.sum(dim=1).int()
635
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
 
645
  query_states: torch.Tensor,
646
  key_states: torch.Tensor,
647
  value_states: torch.Tensor,
648
+ attention_mask_2d: Optional[torch.Tensor] = None,
649
  causal: bool = False,
650
  ) -> torch.Tensor:
651
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
 
718
  seq_len: int,
719
  device: torch.device,
720
  attention_mask: Optional[torch.Tensor] = None,
721
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
722
  """Build padding masks once for all encoder layers.
723
 
724
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
 
849
  last_hidden_state: Optional[torch.Tensor] = None
850
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
851
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
852
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
853
 
854
 
855
  @dataclass
 
857
  last_hidden_state: Optional[torch.Tensor] = None
858
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
859
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
860
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
861
 
862
 
863
  class DPLM2Config(EsmConfig):
 
997
  def forward(
998
  self,
999
  hidden_states: torch.Tensor,
1000
+ attention_mask_2d: Optional[torch.Tensor] = None,
1001
+ attention_mask_4d: Optional[torch.Tensor] = None,
1002
+ flex_block_mask: Optional[BlockMask] = None,
1003
  output_attentions: bool = False,
1004
  output_s_max: bool = False,
1005
  type_ids: Optional[torch.Tensor] = None,
1006
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1007
  batch_size, seq_length = hidden_states.shape[:-1]
1008
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
1009
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
 
1030
  query_BHLD: torch.Tensor,
1031
  key_BHLD: torch.Tensor,
1032
  value_BHLD: torch.Tensor,
1033
+ attention_mask_2d: Optional[torch.Tensor] = None,
1034
+ attention_mask_4d: Optional[torch.Tensor] = None,
1035
+ flex_block_mask: Optional[BlockMask] = None,
1036
  output_attentions: bool = False,
1037
  output_s_max: bool = False,
1038
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1039
  if output_attentions:
1040
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
1041
 
 
1052
  return attn_output, attn_weights, s_max
1053
 
1054
  @torch.no_grad()
1055
+ def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
1056
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
1057
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
1058
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
 
1063
  query_BHLD: torch.Tensor,
1064
  key_BHLD: torch.Tensor,
1065
  value_BHLD: torch.Tensor,
1066
+ attention_mask_4d: Optional[torch.Tensor] = None,
1067
  output_s_max: bool = False,
1068
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
1069
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
1070
  if attention_mask_4d is not None:
1071
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
 
1082
  query_BHLD: torch.Tensor,
1083
  key_BHLD: torch.Tensor,
1084
  value_BHLD: torch.Tensor,
1085
+ attention_mask_2d: Optional[torch.Tensor] = None,
1086
+ ) -> Tuple[torch.Tensor, None]:
1087
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1088
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1089
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
1098
  query_BHLD: torch.Tensor,
1099
  key_BHLD: torch.Tensor,
1100
  value_BHLD: torch.Tensor,
1101
+ flex_block_mask: Optional[BlockMask] = None,
1102
+ ) -> Tuple[torch.Tensor, None]:
1103
  assert flex_attention is not None, "Flex attention is not available in this environment."
1104
  fn = _get_flex_attention_fn()
1105
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
 
1110
  query_BHLD: torch.Tensor,
1111
  key_BHLD: torch.Tensor,
1112
  value_BHLD: torch.Tensor,
1113
+ attention_mask_4d: Optional[torch.Tensor] = None,
1114
+ ) -> Tuple[torch.Tensor, None]:
1115
  context_BHLD = F.scaled_dot_product_attention(
1116
  query_BHLD, key_BHLD, value_BHLD,
1117
  attn_mask=attention_mask_4d,
 
1131
  def forward(
1132
  self,
1133
  hidden_states: torch.Tensor,
1134
+ attention_mask_2d: Optional[torch.Tensor] = None,
1135
+ attention_mask_4d: Optional[torch.Tensor] = None,
1136
+ flex_block_mask: Optional[BlockMask] = None,
1137
  output_attentions: bool = False,
1138
  output_s_max: bool = False,
1139
  type_ids: Optional[torch.Tensor] = None,
1140
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1141
  hidden_states_ln = self.LayerNorm(hidden_states)
1142
  attn_output, attn_weights, s_max = self.self(
1143
  hidden_states_ln,
 
1165
  def forward(
1166
  self,
1167
  hidden_states: torch.Tensor,
1168
+ attention_mask_2d: Optional[torch.Tensor] = None,
1169
+ attention_mask_4d: Optional[torch.Tensor] = None,
1170
+ flex_block_mask: Optional[BlockMask] = None,
1171
  output_attentions: bool = False,
1172
  output_s_max: bool = False,
1173
  type_ids: Optional[torch.Tensor] = None,
1174
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1175
  attention_output, attn_weights, s_max = self.attention(
1176
  hidden_states,
1177
  attention_mask_2d=attention_mask_2d,