diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 2b0d506eb895..cf5379fc619c 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -40,7 +40,8 @@ def generate_simple( attn_impl = { "sdpa_paged": "sdpa", "eager_paged": "eager", - "flash_paged": "flash_attention_2", + "paged_attention": "eager", # TODO: this does not work on AMD docker + "flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker }[attn_impl] model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 2f11f452c1bb..1d1db72a7605 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -6,11 +6,21 @@ from ..utils import is_flash_attn_2_available +# For some reason, if we dont assign the function to a variable here, it will be garbage collected try: if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func # noqa: F401 -except Exception: - pass + + FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func + else: + raise RuntimeError( + "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it" + ) +except Exception as e: + msg = repr(e) + + def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs): + raise Exception(f"flash_attn_varlen_func is not available: {msg}") def paged_attention_forward( @@ -63,6 +73,8 @@ def paged_attention_forward( if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"): flash_attn_varlen_func = implementation.flash_attn_varlen_func + else: + flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}