File size: 2,087 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 2b0d506eb895..cf5379fc619c 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -40,7 +40,8 @@ def generate_simple(
     attn_impl = {
         "sdpa_paged": "sdpa",
         "eager_paged": "eager",
-        "flash_paged": "flash_attention_2",
+        "paged_attention": "eager",  # TODO: this does not work on AMD docker
+        "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
     }[attn_impl]
 
     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index 2f11f452c1bb..1d1db72a7605 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -6,11 +6,21 @@
 from ..utils import is_flash_attn_2_available
 
 
+# For some reason, if we don't assign the function to a variable here, it will be garbage collected
 try:
     if is_flash_attn_2_available():
         from flash_attn import flash_attn_varlen_func  # noqa: F401
-except Exception:
-    pass
+
+        FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func
+    else:
+        raise RuntimeError(
+            "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it"
+        )
+except Exception as e:
+    msg = repr(e)
+
+    def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs):
+        raise Exception(f"flash_attn_varlen_func is not available: {msg}")
 
 
 def paged_attention_forward(
@@ -63,6 +73,8 @@ def paged_attention_forward(
 
     if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
+    else:
+        flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC
 
     custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}