Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +48 -37
modeling_fastesm.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._inductor.config as inductor_config
|
| 3 |
import torch._dynamo as dynamo
|
|
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
|
|
| 429 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 430 |
"""
|
| 431 |
from enum import Enum
|
| 432 |
-
from
|
|
|
|
| 433 |
|
| 434 |
import torch
|
| 435 |
import torch.nn as nn
|
|
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
|
|
| 447 |
|
| 448 |
|
| 449 |
def _get_flex_attention_fn():
|
| 450 |
-
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
global _compiled_flex_attention
|
| 452 |
if flex_attention is None:
|
| 453 |
return None
|
|
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
|
|
| 455 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 456 |
return flex_attention
|
| 457 |
if _compiled_flex_attention is None:
|
| 458 |
-
_compiled_flex_attention = torch.compile(
|
|
|
|
|
|
|
|
|
|
| 459 |
return _compiled_flex_attention
|
| 460 |
|
| 461 |
|
| 462 |
### Kernels Flash Attention Detection
|
| 463 |
-
def _infer_kernels_flash_variant(kernel) -> str
|
| 464 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 465 |
return "flash_attn2"
|
| 466 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
| 576 |
).reshape(-1, *other_shape)
|
| 577 |
|
| 578 |
@staticmethod
|
| 579 |
-
def backward(ctx, grad_output) ->
|
| 580 |
(indices,) = ctx.saved_tensors
|
| 581 |
assert grad_output.ndim >= 2
|
| 582 |
other_shape = grad_output.shape[1:]
|
|
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
|
| 599 |
return output
|
| 600 |
|
| 601 |
@staticmethod
|
| 602 |
-
def backward(ctx, grad_output) ->
|
| 603 |
(indices,) = ctx.saved_tensors
|
| 604 |
return grad_output[indices], None, None
|
| 605 |
|
|
@@ -618,7 +629,7 @@ def _unpad_input(
|
|
| 618 |
key_layer: torch.Tensor,
|
| 619 |
value_layer: torch.Tensor,
|
| 620 |
attention_mask_2d: torch.Tensor,
|
| 621 |
-
) ->
|
| 622 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 623 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 624 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
|
|
| 634 |
query_states: torch.Tensor,
|
| 635 |
key_states: torch.Tensor,
|
| 636 |
value_states: torch.Tensor,
|
| 637 |
-
attention_mask_2d: torch.Tensor
|
| 638 |
causal: bool = False,
|
| 639 |
) -> torch.Tensor:
|
| 640 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
@@ -707,7 +718,7 @@ def get_attention_mask(
|
|
| 707 |
seq_len: int,
|
| 708 |
device: torch.device,
|
| 709 |
attention_mask: Optional[torch.Tensor] = None,
|
| 710 |
-
) ->
|
| 711 |
"""Build padding masks once for all encoder layers.
|
| 712 |
|
| 713 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
@@ -738,7 +749,7 @@ def get_attention_mask(
|
|
| 738 |
import torch
|
| 739 |
import torch.nn as nn
|
| 740 |
from torch.nn import functional as F
|
| 741 |
-
from typing import
|
| 742 |
from einops import rearrange
|
| 743 |
from dataclasses import dataclass
|
| 744 |
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
|
@@ -762,7 +773,7 @@ class FastEsmEncoderOutput(ModelOutput):
|
|
| 762 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 763 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 764 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 765 |
-
s_max: Optional[Tuple[
|
| 766 |
|
| 767 |
|
| 768 |
@dataclass
|
|
@@ -772,7 +783,7 @@ class EsmMaskedLMOutput(ModelOutput):
|
|
| 772 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 773 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 774 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 775 |
-
s_max: Optional[Tuple[
|
| 776 |
|
| 777 |
|
| 778 |
class FastEsmConfig(PretrainedConfig):
|
|
@@ -858,12 +869,12 @@ class EsmSelfAttention(nn.Module):
|
|
| 858 |
def forward(
|
| 859 |
self,
|
| 860 |
hidden_states: torch.Tensor,
|
| 861 |
-
attention_mask_2d: torch.Tensor
|
| 862 |
-
attention_mask_4d: torch.Tensor
|
| 863 |
-
flex_block_mask:
|
| 864 |
output_attentions: bool = False,
|
| 865 |
output_s_max: bool = False,
|
| 866 |
-
) ->
|
| 867 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 868 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 869 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
@@ -890,12 +901,12 @@ class EsmSelfAttention(nn.Module):
|
|
| 890 |
query_BHLD: torch.Tensor,
|
| 891 |
key_BHLD: torch.Tensor,
|
| 892 |
value_BHLD: torch.Tensor,
|
| 893 |
-
attention_mask_2d: torch.Tensor
|
| 894 |
-
attention_mask_4d: torch.Tensor
|
| 895 |
-
flex_block_mask:
|
| 896 |
output_attentions: bool = False,
|
| 897 |
output_s_max: bool = False,
|
| 898 |
-
) ->
|
| 899 |
if output_attentions:
|
| 900 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 901 |
|
|
@@ -912,7 +923,7 @@ class EsmSelfAttention(nn.Module):
|
|
| 912 |
return attn_output, attn_weights, s_max
|
| 913 |
|
| 914 |
@torch.no_grad()
|
| 915 |
-
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) ->
|
| 916 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 917 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 918 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
@@ -923,9 +934,9 @@ class EsmSelfAttention(nn.Module):
|
|
| 923 |
query_BHLD: torch.Tensor,
|
| 924 |
key_BHLD: torch.Tensor,
|
| 925 |
value_BHLD: torch.Tensor,
|
| 926 |
-
attention_mask_4d: torch.Tensor
|
| 927 |
output_s_max: bool = False,
|
| 928 |
-
) ->
|
| 929 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 930 |
if attention_mask_4d is not None:
|
| 931 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
@@ -942,8 +953,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 942 |
query_BHLD: torch.Tensor,
|
| 943 |
key_BHLD: torch.Tensor,
|
| 944 |
value_BHLD: torch.Tensor,
|
| 945 |
-
attention_mask_2d: torch.Tensor
|
| 946 |
-
) ->
|
| 947 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 948 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 949 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
@@ -958,8 +969,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 958 |
query_BHLD: torch.Tensor,
|
| 959 |
key_BHLD: torch.Tensor,
|
| 960 |
value_BHLD: torch.Tensor,
|
| 961 |
-
flex_block_mask:
|
| 962 |
-
) ->
|
| 963 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 964 |
fn = _get_flex_attention_fn()
|
| 965 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
@@ -970,8 +981,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 970 |
query_BHLD: torch.Tensor,
|
| 971 |
key_BHLD: torch.Tensor,
|
| 972 |
value_BHLD: torch.Tensor,
|
| 973 |
-
attention_mask_4d: torch.Tensor
|
| 974 |
-
) ->
|
| 975 |
context_BHLD = F.scaled_dot_product_attention(
|
| 976 |
query_BHLD, key_BHLD, value_BHLD,
|
| 977 |
attn_mask=attention_mask_4d,
|
|
@@ -991,12 +1002,12 @@ class EsmAttention(nn.Module):
|
|
| 991 |
def forward(
|
| 992 |
self,
|
| 993 |
hidden_states: torch.Tensor,
|
| 994 |
-
attention_mask_2d: torch.Tensor
|
| 995 |
-
attention_mask_4d: torch.Tensor
|
| 996 |
-
flex_block_mask:
|
| 997 |
output_attentions: bool = False,
|
| 998 |
output_s_max: bool = False,
|
| 999 |
-
) ->
|
| 1000 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1001 |
attn_output, attn_weights, s_max = self.self(
|
| 1002 |
hidden_states_ln,
|
|
@@ -1023,12 +1034,12 @@ class EsmLayer(nn.Module):
|
|
| 1023 |
def forward(
|
| 1024 |
self,
|
| 1025 |
hidden_states: torch.Tensor,
|
| 1026 |
-
attention_mask_2d: torch.Tensor
|
| 1027 |
-
attention_mask_4d: torch.Tensor
|
| 1028 |
-
flex_block_mask:
|
| 1029 |
output_attentions: bool = False,
|
| 1030 |
output_s_max: bool = False,
|
| 1031 |
-
) ->
|
| 1032 |
attention_output, attn_weights, s_max = self.attention(
|
| 1033 |
hidden_states,
|
| 1034 |
attention_mask_2d=attention_mask_2d,
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch._inductor.config as inductor_config
|
| 5 |
import torch._dynamo as dynamo
|
|
|
|
| 431 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 432 |
"""
|
| 433 |
from enum import Enum
|
| 434 |
+
from functools import partial
|
| 435 |
+
from typing import Dict, List, Optional, Tuple
|
| 436 |
|
| 437 |
import torch
|
| 438 |
import torch.nn as nn
|
|
|
|
| 450 |
|
| 451 |
|
| 452 |
def _get_flex_attention_fn():
|
| 453 |
+
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
| 454 |
+
|
| 455 |
+
Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
|
| 456 |
+
on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
|
| 457 |
+
on older hardware.
|
| 458 |
+
"""
|
| 459 |
global _compiled_flex_attention
|
| 460 |
if flex_attention is None:
|
| 461 |
return None
|
|
|
|
| 463 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 464 |
return flex_attention
|
| 465 |
if _compiled_flex_attention is None:
|
| 466 |
+
_compiled_flex_attention = torch.compile(
|
| 467 |
+
partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
|
| 468 |
+
dynamic=False,
|
| 469 |
+
)
|
| 470 |
return _compiled_flex_attention
|
| 471 |
|
| 472 |
|
| 473 |
### Kernels Flash Attention Detection
|
| 474 |
+
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| 475 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 476 |
return "flash_attn2"
|
| 477 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
|
|
| 587 |
).reshape(-1, *other_shape)
|
| 588 |
|
| 589 |
@staticmethod
|
| 590 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
|
| 591 |
(indices,) = ctx.saved_tensors
|
| 592 |
assert grad_output.ndim >= 2
|
| 593 |
other_shape = grad_output.shape[1:]
|
|
|
|
| 610 |
return output
|
| 611 |
|
| 612 |
@staticmethod
|
| 613 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
|
| 614 |
(indices,) = ctx.saved_tensors
|
| 615 |
return grad_output[indices], None, None
|
| 616 |
|
|
|
|
| 629 |
key_layer: torch.Tensor,
|
| 630 |
value_layer: torch.Tensor,
|
| 631 |
attention_mask_2d: torch.Tensor,
|
| 632 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 633 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 634 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 635 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
|
|
| 645 |
query_states: torch.Tensor,
|
| 646 |
key_states: torch.Tensor,
|
| 647 |
value_states: torch.Tensor,
|
| 648 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 649 |
causal: bool = False,
|
| 650 |
) -> torch.Tensor:
|
| 651 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
|
|
| 718 |
seq_len: int,
|
| 719 |
device: torch.device,
|
| 720 |
attention_mask: Optional[torch.Tensor] = None,
|
| 721 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
|
| 722 |
"""Build padding masks once for all encoder layers.
|
| 723 |
|
| 724 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
|
|
| 749 |
import torch
|
| 750 |
import torch.nn as nn
|
| 751 |
from torch.nn import functional as F
|
| 752 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 753 |
from einops import rearrange
|
| 754 |
from dataclasses import dataclass
|
| 755 |
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
|
|
|
| 773 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 774 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 775 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 776 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 777 |
|
| 778 |
|
| 779 |
@dataclass
|
|
|
|
| 783 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 784 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 785 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 786 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 787 |
|
| 788 |
|
| 789 |
class FastEsmConfig(PretrainedConfig):
|
|
|
|
| 869 |
def forward(
|
| 870 |
self,
|
| 871 |
hidden_states: torch.Tensor,
|
| 872 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 873 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 874 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 875 |
output_attentions: bool = False,
|
| 876 |
output_s_max: bool = False,
|
| 877 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 878 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 879 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 880 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
|
|
| 901 |
query_BHLD: torch.Tensor,
|
| 902 |
key_BHLD: torch.Tensor,
|
| 903 |
value_BHLD: torch.Tensor,
|
| 904 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 905 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 906 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 907 |
output_attentions: bool = False,
|
| 908 |
output_s_max: bool = False,
|
| 909 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 910 |
if output_attentions:
|
| 911 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 912 |
|
|
|
|
| 923 |
return attn_output, attn_weights, s_max
|
| 924 |
|
| 925 |
@torch.no_grad()
|
| 926 |
+
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
|
| 927 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 928 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 929 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
|
|
| 934 |
query_BHLD: torch.Tensor,
|
| 935 |
key_BHLD: torch.Tensor,
|
| 936 |
value_BHLD: torch.Tensor,
|
| 937 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 938 |
output_s_max: bool = False,
|
| 939 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 940 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 941 |
if attention_mask_4d is not None:
|
| 942 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
|
|
| 953 |
query_BHLD: torch.Tensor,
|
| 954 |
key_BHLD: torch.Tensor,
|
| 955 |
value_BHLD: torch.Tensor,
|
| 956 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 957 |
+
) -> Tuple[torch.Tensor, None]:
|
| 958 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 959 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 960 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
| 969 |
query_BHLD: torch.Tensor,
|
| 970 |
key_BHLD: torch.Tensor,
|
| 971 |
value_BHLD: torch.Tensor,
|
| 972 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 973 |
+
) -> Tuple[torch.Tensor, None]:
|
| 974 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 975 |
fn = _get_flex_attention_fn()
|
| 976 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
|
|
| 981 |
query_BHLD: torch.Tensor,
|
| 982 |
key_BHLD: torch.Tensor,
|
| 983 |
value_BHLD: torch.Tensor,
|
| 984 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 985 |
+
) -> Tuple[torch.Tensor, None]:
|
| 986 |
context_BHLD = F.scaled_dot_product_attention(
|
| 987 |
query_BHLD, key_BHLD, value_BHLD,
|
| 988 |
attn_mask=attention_mask_4d,
|
|
|
|
| 1002 |
def forward(
|
| 1003 |
self,
|
| 1004 |
hidden_states: torch.Tensor,
|
| 1005 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1006 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1007 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1008 |
output_attentions: bool = False,
|
| 1009 |
output_s_max: bool = False,
|
| 1010 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1011 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1012 |
attn_output, attn_weights, s_max = self.self(
|
| 1013 |
hidden_states_ln,
|
|
|
|
| 1034 |
def forward(
|
| 1035 |
self,
|
| 1036 |
hidden_states: torch.Tensor,
|
| 1037 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1038 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1039 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1040 |
output_attentions: bool = False,
|
| 1041 |
output_s_max: bool = False,
|
| 1042 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1043 |
attention_output, attn_weights, s_max = self.attention(
|
| 1044 |
hidden_states,
|
| 1045 |
attention_mask_2d=attention_mask_2d,
|