lhallee committed on
Commit
34601d3
·
verified ·
1 Parent(s): 0b91ce1

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +48 -37
modeling_fastesm.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch._inductor.config as inductor_config
3
  import torch._dynamo as dynamo
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
429
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
430
  """
431
  from enum import Enum
432
- from typing import Optional
 
433
 
434
  import torch
435
  import torch.nn as nn
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
447
 
448
 
449
  def _get_flex_attention_fn():
450
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
 
 
 
 
 
451
  global _compiled_flex_attention
452
  if flex_attention is None:
453
  return None
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
455
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
456
  return flex_attention
457
  if _compiled_flex_attention is None:
458
- _compiled_flex_attention = torch.compile(flex_attention)
 
 
 
459
  return _compiled_flex_attention
460
 
461
 
462
  ### Kernels Flash Attention Detection
463
- def _infer_kernels_flash_variant(kernel) -> str | None:
464
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
465
  return "flash_attn2"
466
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
576
  ).reshape(-1, *other_shape)
577
 
578
  @staticmethod
579
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
580
  (indices,) = ctx.saved_tensors
581
  assert grad_output.ndim >= 2
582
  other_shape = grad_output.shape[1:]
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
599
  return output
600
 
601
  @staticmethod
602
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
603
  (indices,) = ctx.saved_tensors
604
  return grad_output[indices], None, None
605
 
@@ -618,7 +629,7 @@ def _unpad_input(
618
  key_layer: torch.Tensor,
619
  value_layer: torch.Tensor,
620
  attention_mask_2d: torch.Tensor,
621
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
622
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
623
  seqlens = attention_mask_2d.sum(dim=1).int()
624
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
634
  query_states: torch.Tensor,
635
  key_states: torch.Tensor,
636
  value_states: torch.Tensor,
637
- attention_mask_2d: torch.Tensor | None = None,
638
  causal: bool = False,
639
  ) -> torch.Tensor:
640
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
@@ -707,7 +718,7 @@ def get_attention_mask(
707
  seq_len: int,
708
  device: torch.device,
709
  attention_mask: Optional[torch.Tensor] = None,
710
- ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
711
  """Build padding masks once for all encoder layers.
712
 
713
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
@@ -738,7 +749,7 @@ def get_attention_mask(
738
  import torch
739
  import torch.nn as nn
740
  from torch.nn import functional as F
741
- from typing import Optional, Tuple, Dict, Any
742
  from einops import rearrange
743
  from dataclasses import dataclass
744
  from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
@@ -762,7 +773,7 @@ class FastEsmEncoderOutput(ModelOutput):
762
  last_hidden_state: Optional[torch.Tensor] = None
763
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
764
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
765
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
766
 
767
 
768
  @dataclass
@@ -772,7 +783,7 @@ class EsmMaskedLMOutput(ModelOutput):
772
  last_hidden_state: Optional[torch.Tensor] = None
773
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
774
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
775
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
776
 
777
 
778
  class FastEsmConfig(PretrainedConfig):
@@ -858,12 +869,12 @@ class EsmSelfAttention(nn.Module):
858
  def forward(
859
  self,
860
  hidden_states: torch.Tensor,
861
- attention_mask_2d: torch.Tensor | None = None,
862
- attention_mask_4d: torch.Tensor | None = None,
863
- flex_block_mask: "BlockMask | None" = None,
864
  output_attentions: bool = False,
865
  output_s_max: bool = False,
866
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
867
  batch_size, seq_length = hidden_states.shape[:-1]
868
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
869
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
@@ -890,12 +901,12 @@ class EsmSelfAttention(nn.Module):
890
  query_BHLD: torch.Tensor,
891
  key_BHLD: torch.Tensor,
892
  value_BHLD: torch.Tensor,
893
- attention_mask_2d: torch.Tensor | None = None,
894
- attention_mask_4d: torch.Tensor | None = None,
895
- flex_block_mask: "BlockMask | None" = None,
896
  output_attentions: bool = False,
897
  output_s_max: bool = False,
898
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
899
  if output_attentions:
900
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
901
 
@@ -912,7 +923,7 @@ class EsmSelfAttention(nn.Module):
912
  return attn_output, attn_weights, s_max
913
 
914
  @torch.no_grad()
915
- def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]:
916
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
917
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
918
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
@@ -923,9 +934,9 @@ class EsmSelfAttention(nn.Module):
923
  query_BHLD: torch.Tensor,
924
  key_BHLD: torch.Tensor,
925
  value_BHLD: torch.Tensor,
926
- attention_mask_4d: torch.Tensor | None = None,
927
  output_s_max: bool = False,
928
- ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
929
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
930
  if attention_mask_4d is not None:
931
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
@@ -942,8 +953,8 @@ class EsmSelfAttention(nn.Module):
942
  query_BHLD: torch.Tensor,
943
  key_BHLD: torch.Tensor,
944
  value_BHLD: torch.Tensor,
945
- attention_mask_2d: torch.Tensor | None = None,
946
- ) -> tuple[torch.Tensor, None]:
947
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
948
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
949
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
@@ -958,8 +969,8 @@ class EsmSelfAttention(nn.Module):
958
  query_BHLD: torch.Tensor,
959
  key_BHLD: torch.Tensor,
960
  value_BHLD: torch.Tensor,
961
- flex_block_mask: "BlockMask | None" = None,
962
- ) -> tuple[torch.Tensor, None]:
963
  assert flex_attention is not None, "Flex attention is not available in this environment."
964
  fn = _get_flex_attention_fn()
965
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
@@ -970,8 +981,8 @@ class EsmSelfAttention(nn.Module):
970
  query_BHLD: torch.Tensor,
971
  key_BHLD: torch.Tensor,
972
  value_BHLD: torch.Tensor,
973
- attention_mask_4d: torch.Tensor | None = None,
974
- ) -> tuple[torch.Tensor, None]:
975
  context_BHLD = F.scaled_dot_product_attention(
976
  query_BHLD, key_BHLD, value_BHLD,
977
  attn_mask=attention_mask_4d,
@@ -991,12 +1002,12 @@ class EsmAttention(nn.Module):
991
  def forward(
992
  self,
993
  hidden_states: torch.Tensor,
994
- attention_mask_2d: torch.Tensor | None = None,
995
- attention_mask_4d: torch.Tensor | None = None,
996
- flex_block_mask: "BlockMask | None" = None,
997
  output_attentions: bool = False,
998
  output_s_max: bool = False,
999
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1000
  hidden_states_ln = self.LayerNorm(hidden_states)
1001
  attn_output, attn_weights, s_max = self.self(
1002
  hidden_states_ln,
@@ -1023,12 +1034,12 @@ class EsmLayer(nn.Module):
1023
  def forward(
1024
  self,
1025
  hidden_states: torch.Tensor,
1026
- attention_mask_2d: torch.Tensor | None = None,
1027
- attention_mask_4d: torch.Tensor | None = None,
1028
- flex_block_mask: "BlockMask | None" = None,
1029
  output_attentions: bool = False,
1030
  output_s_max: bool = False,
1031
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1032
  attention_output, attn_weights, s_max = self.attention(
1033
  hidden_states,
1034
  attention_mask_2d=attention_mask_2d,
 
1
+ from __future__ import annotations
2
+
3
  import torch
4
  import torch._inductor.config as inductor_config
5
  import torch._dynamo as dynamo
 
431
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
  """
433
  from enum import Enum
434
+ from functools import partial
435
+ from typing import Dict, List, Optional, Tuple
436
 
437
  import torch
438
  import torch.nn as nn
 
450
 
451
 
452
  def _get_flex_attention_fn():
453
+ """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
454
+
455
+ Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
456
+ on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
457
+ on older hardware.
458
+ """
459
  global _compiled_flex_attention
460
  if flex_attention is None:
461
  return None
 
463
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
464
  return flex_attention
465
  if _compiled_flex_attention is None:
466
+ _compiled_flex_attention = torch.compile(
467
+ partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
468
+ dynamic=False,
469
+ )
470
  return _compiled_flex_attention
471
 
472
 
473
  ### Kernels Flash Attention Detection
474
+ def _infer_kernels_flash_variant(kernel) -> Optional[str]:
475
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
476
  return "flash_attn2"
477
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
 
587
  ).reshape(-1, *other_shape)
588
 
589
  @staticmethod
590
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
591
  (indices,) = ctx.saved_tensors
592
  assert grad_output.ndim >= 2
593
  other_shape = grad_output.shape[1:]
 
610
  return output
611
 
612
  @staticmethod
613
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
614
  (indices,) = ctx.saved_tensors
615
  return grad_output[indices], None, None
616
 
 
629
  key_layer: torch.Tensor,
630
  value_layer: torch.Tensor,
631
  attention_mask_2d: torch.Tensor,
632
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
633
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
634
  seqlens = attention_mask_2d.sum(dim=1).int()
635
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
 
645
  query_states: torch.Tensor,
646
  key_states: torch.Tensor,
647
  value_states: torch.Tensor,
648
+ attention_mask_2d: Optional[torch.Tensor] = None,
649
  causal: bool = False,
650
  ) -> torch.Tensor:
651
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
 
718
  seq_len: int,
719
  device: torch.device,
720
  attention_mask: Optional[torch.Tensor] = None,
721
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
722
  """Build padding masks once for all encoder layers.
723
 
724
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
 
749
  import torch
750
  import torch.nn as nn
751
  from torch.nn import functional as F
752
+ from typing import Any, Dict, List, Optional, Tuple
753
  from einops import rearrange
754
  from dataclasses import dataclass
755
  from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 
773
  last_hidden_state: Optional[torch.Tensor] = None
774
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
775
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
776
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
777
 
778
 
779
  @dataclass
 
783
  last_hidden_state: Optional[torch.Tensor] = None
784
  hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
785
  attentions: Optional[Tuple[torch.Tensor, ...]] = None
786
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
787
 
788
 
789
  class FastEsmConfig(PretrainedConfig):
 
869
  def forward(
870
  self,
871
  hidden_states: torch.Tensor,
872
+ attention_mask_2d: Optional[torch.Tensor] = None,
873
+ attention_mask_4d: Optional[torch.Tensor] = None,
874
+ flex_block_mask: Optional[BlockMask] = None,
875
  output_attentions: bool = False,
876
  output_s_max: bool = False,
877
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
878
  batch_size, seq_length = hidden_states.shape[:-1]
879
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
880
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
 
901
  query_BHLD: torch.Tensor,
902
  key_BHLD: torch.Tensor,
903
  value_BHLD: torch.Tensor,
904
+ attention_mask_2d: Optional[torch.Tensor] = None,
905
+ attention_mask_4d: Optional[torch.Tensor] = None,
906
+ flex_block_mask: Optional[BlockMask] = None,
907
  output_attentions: bool = False,
908
  output_s_max: bool = False,
909
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
910
  if output_attentions:
911
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
912
 
 
923
  return attn_output, attn_weights, s_max
924
 
925
  @torch.no_grad()
926
+ def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
927
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
928
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
929
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
 
934
  query_BHLD: torch.Tensor,
935
  key_BHLD: torch.Tensor,
936
  value_BHLD: torch.Tensor,
937
+ attention_mask_4d: Optional[torch.Tensor] = None,
938
  output_s_max: bool = False,
939
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
940
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
941
  if attention_mask_4d is not None:
942
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
 
953
  query_BHLD: torch.Tensor,
954
  key_BHLD: torch.Tensor,
955
  value_BHLD: torch.Tensor,
956
+ attention_mask_2d: Optional[torch.Tensor] = None,
957
+ ) -> Tuple[torch.Tensor, None]:
958
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
959
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
960
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
969
  query_BHLD: torch.Tensor,
970
  key_BHLD: torch.Tensor,
971
  value_BHLD: torch.Tensor,
972
+ flex_block_mask: Optional[BlockMask] = None,
973
+ ) -> Tuple[torch.Tensor, None]:
974
  assert flex_attention is not None, "Flex attention is not available in this environment."
975
  fn = _get_flex_attention_fn()
976
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
 
981
  query_BHLD: torch.Tensor,
982
  key_BHLD: torch.Tensor,
983
  value_BHLD: torch.Tensor,
984
+ attention_mask_4d: Optional[torch.Tensor] = None,
985
+ ) -> Tuple[torch.Tensor, None]:
986
  context_BHLD = F.scaled_dot_product_attention(
987
  query_BHLD, key_BHLD, value_BHLD,
988
  attn_mask=attention_mask_4d,
 
1002
  def forward(
1003
  self,
1004
  hidden_states: torch.Tensor,
1005
+ attention_mask_2d: Optional[torch.Tensor] = None,
1006
+ attention_mask_4d: Optional[torch.Tensor] = None,
1007
+ flex_block_mask: Optional[BlockMask] = None,
1008
  output_attentions: bool = False,
1009
  output_s_max: bool = False,
1010
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1011
  hidden_states_ln = self.LayerNorm(hidden_states)
1012
  attn_output, attn_weights, s_max = self.self(
1013
  hidden_states_ln,
 
1034
  def forward(
1035
  self,
1036
  hidden_states: torch.Tensor,
1037
+ attention_mask_2d: Optional[torch.Tensor] = None,
1038
+ attention_mask_4d: Optional[torch.Tensor] = None,
1039
+ flex_block_mask: Optional[BlockMask] = None,
1040
  output_attentions: bool = False,
1041
  output_s_max: bool = False,
1042
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1043
  attention_output, attn_weights, s_max = self.attention(
1044
  hidden_states,
1045
  attention_mask_2d=attention_mask_2d,