harness / diffs /41078.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 2b0d506eb895..cf5379fc619c 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -40,7 +40,8 @@ def generate_simple(
attn_impl = {
"sdpa_paged": "sdpa",
"eager_paged": "eager",
- "flash_paged": "flash_attention_2",
+ "paged_attention": "eager", # TODO: this does not work on AMD docker
+ "flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
}[attn_impl]
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index 2f11f452c1bb..1d1db72a7605 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -6,11 +6,21 @@
from ..utils import is_flash_attn_2_available
+# For some reason, if we don't assign the function to a variable here, it will be garbage collected
try:
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func # noqa: F401
-except Exception:
- pass
+
+ FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func
+ else:
+ raise RuntimeError(
+ "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it"
+ )
+except Exception as e:
+ msg = repr(e)
+
+ def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs):
+ raise Exception(f"flash_attn_varlen_func is not available: {msg}")
def paged_attention_forward(
@@ -63,6 +73,8 @@ def paged_attention_forward(
if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
flash_attn_varlen_func = implementation.flash_attn_varlen_func
+ else:
+ flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC
custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}