Workaround for issue: get_imports failing to respect conditionals on imports
https://github.com/huggingface/transformers/issues/28459
This should allow the code to work without the flash_attn (flash2) module installed, and allow it to run on a CPU.
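
For context on why the try/except wrapper matters: when a model is loaded with trust_remote_code, transformers scans the modelling file with get_imports(), which collects import statements by regex rather than executing the file, so it ignores the "if is_flash_attn_2_available():" guard but drops imports that sit inside a try/except block. Below is a minimal sketch of that behavior, assuming get_imports() still works as described in the linked issue (later transformers releases may differ); the imports_of() helper exists only for this illustration.

```python
import os
import tempfile

# get_imports() is the scanner transformers runs on remote-code modelling files.
from transformers.dynamic_module_utils import get_imports


def imports_of(source: str) -> list:
    # get_imports() takes a filename, so write the snippet to a temporary file first.
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    try:
        return get_imports(path)
    finally:
        os.remove(path)


guarded = (
    "if is_flash_attn_2_available():\n"
    "    from flash_attn import flash_attn_func\n"
)
wrapped = (
    "try:\n"
    "    from flash_attn import flash_attn_func\n"
    "except ImportError:\n"
    "    flash_attn_func = None\n"
)

# The conditional import is still reported, so flash_attn looks like a hard
# requirement; the try/except version is expected to be skipped.
print(imports_of(guarded))  # expected: ['flash_attn']
print(imports_of(wrapped))  # expected: []
```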
- modelling_walsh.py +14 -7
modelling_walsh.py (CHANGED):

```diff
@@ -27,6 +27,13 @@ from transformers.utils import (
     is_flash_attn_greater_or_equal_2_10,
 )
 
+# Workaround for https://github.com/huggingface/transformers/issues/28459
+if is_flash_attn_2_available():
+    try:
+        from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+    except:
+        print("Could not import flash2")
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 
@@ -825,7 +832,7 @@ class CausalSelfAttention(nn.Module):
         init.constant_(self.output_linear.bias, 0.)
 
     # Project QKV input through input matrices, reshape to (batch_size, n_heads, seq_len, d_model), and apply cache.
-    def
+    def _project_input(self, qkv, past_key_values):
         batch_size, seq_len, d_embed = qkv.shape
         proj = self.in_proj(qkv)
         query, key, value = proj.chunk(chunks=3, dim=-1)
@@ -857,15 +864,15 @@ class CausalSelfAttention(nn.Module):
 
         if attn_type == "flash2":
             if use_cache is None or use_cache == False:
-                return self.
+                return self._flash2_forward(qkv)
             else:
-                return self.
+                return self._flash2_forward_cached(qkv, past_key_values)
 
         # qkv: (batch_size, seq_len, d_embed)
         batch_size, seq_len, d_embed = qkv.shape
 
         # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.
+        query, key, value = self._project_input(qkv, past_key_values)
         kv_seq_len = key.shape[-2]
 
         # Default to returning empty attention weights.
@@ -922,7 +929,7 @@ class CausalSelfAttention(nn.Module):
         )
 
     # No cache support, but faster
-    def
+    def _flash2_forward(
         self,
         qkv,
     ):
@@ -961,7 +968,7 @@ class CausalSelfAttention(nn.Module):
 
     # See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
     #https://huggingface.co/docs/transformers/internal/generation_utils
-    def
+    def _flash2_forward_cached(
         self,
         qkv,
         past_key_values,
@@ -969,7 +976,7 @@ class CausalSelfAttention(nn.Module):
         batch_size, seq_len, d_embed = qkv.shape
 
         # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.
+        query, key, value = self._project_input(qkv, past_key_values)
         query, key, value = self._downcast_to_float16(query, key, value)
 
         # Expected inputs to flash2:
```