| |
| |
| |
| |
| @@ -40,7 +40,8 @@ def generate_simple( |
| attn_impl = { |
| "sdpa_paged": "sdpa", |
| "eager_paged": "eager", |
| - "flash_paged": "flash_attention_2", |
| + "paged_attention": "eager", # TODO: this does not work on AMD docker |
| + "flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker |
| }[attn_impl] |
| |
| model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl) |
| |
| |
| |
| |
| @@ -6,11 +6,21 @@ |
| from ..utils import is_flash_attn_2_available |
| |
| |
| +# For some reason, if we dont assign the function to a variable here, it will be garbage collected |
| try: |
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_varlen_func # noqa: F401 |
| -except Exception: |
| - pass |
| + |
| + FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func |
| + else: |
| + raise RuntimeError( |
| + "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it" |
| + ) |
| +except Exception as e: |
| + msg = repr(e) |
| + |
| + def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs): |
| + raise Exception(f"flash_attn_varlen_func is not available: {msg}") |
| |
| |
| def paged_attention_forward( |
| @@ -63,6 +73,8 @@ def paged_attention_forward( |
| |
| if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"): |
| flash_attn_varlen_func = implementation.flash_attn_varlen_func |
| + else: |
| + flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC |
| |
| custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {} |
| |
|
|