| |
| |
| |
| |
| @@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None: |
| if not is_kernel(attn_implementation): |
| return |
| if not _kernels_available: |
| - raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") |
| + raise ImportError( |
| + "`kernels` is either not installed or uses an incompatible version. " |
| + "Please install the latest version with `pip install -U kernels`." |
| + ) |
| |
| # Need to be imported here as otherwise we have a circular import in `modeling_utils` |
| from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS |
| |
| |
| |
| |
| @@ -2675,35 +2675,39 @@ def _check_and_adjust_attn_implementation( |
| None to sdpa (to potentially eager). |
| """ |
| applicable_attn_implementation = attn_implementation |
| + |
| # If FA not installed, do not fail but use kernels instead |
| if ( |
| - applicable_attn_implementation == "flash_attention_2" |
| + attn_implementation is not None |
| + and attn_implementation.startswith("flash_attention") |
| and self._supports_flash_attn |
| - and not is_flash_attn_2_available() |
| + and not (is_flash_attn_2_available() or is_flash_attn_3_available()) |
| and is_kernels_available() |
| ): |
| - applicable_attn_implementation = "kernels-community/flash-attn" |
| + if attn_implementation.endswith("2"): |
| + applicable_attn_implementation = "kernels-community/flash-attn" |
| + else: |
| + applicable_attn_implementation = "kernels-community/vllm-flash-attn3" |
| + |
| if is_kernel(applicable_attn_implementation): |
| try: |
| load_and_register_kernel(applicable_attn_implementation) |
| # log that we used kernel fallback if successful |
| - if attn_implementation == "flash_attention_2": |
| + if attn_implementation.startswith("flash_attention"): |
| logger.warning_once( |
| - "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` " |
| - "library instead!" |
| + f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` " |
| + "from the `kernels` library instead!" |
| ) |
| except Exception as e: |
| - if attn_implementation == "flash_attention_2": |
| - self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception |
| - logger.warning_once( |
| - f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " |
| - f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." |
| - ) |
| - try: |
| - self._sdpa_can_dispatch(is_init_check) |
| - applicable_attn_implementation = "sdpa" |
| - except (ValueError, ImportError): |
| - applicable_attn_implementation = "eager" |
| + # raise the proper exception for requested flash attention |
| + if attn_implementation.startswith("flash_attention"): |
| + if attn_implementation.endswith("2"): |
| + self._flash_attn_2_can_dispatch() |
| + else: |
| + self._flash_attn_3_can_dispatch() |
| + |
| + # error out properly if a kernel was specifically requested |
| + raise e |
| else: |
| applicable_attn_implementation = self.get_correct_attn_implementation( |
| applicable_attn_implementation, is_init_check |
| |
| |
| |
| |
| @@ -88,6 +88,7 @@ |
| from transformers.utils.import_utils import ( |
| is_flash_attn_2_available, |
| is_flash_attn_3_available, |
| + is_kernels_available, |
| is_torch_npu_available, |
| ) |
| |
| @@ -2849,6 +2850,9 @@ def test_not_available_flash(self): |
| reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." |
| ) |
| |
| + if is_kernels_available(): |
| + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`") |
| + |
| with self.assertRaises(ImportError) as cm: |
| _ = AutoModel.from_pretrained( |
| "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" |
| @@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self): |
| reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." |
| ) |
| |
| + if is_kernels_available(): |
| + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`") |
| + |
| config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") |
| |
| with self.assertRaises(ImportError) as cm: |
| @@ -2875,6 +2882,41 @@ def test_not_available_flash_with_config(self): |
| |
| self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) |
| |
| + def test_kernels_fallback(self): |
| + # Requesting `flash_attention_2` without `flash_attn` should fall back to the hub kernel and log a warning. |
| + if not is_kernels_available(): |
| + self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`") |
| + # The kernels fallback is only taken when neither FA2 nor FA3 is installed, so skip if either is available. |
| + if is_flash_attn_2_available() or is_flash_attn_3_available(): |
| + self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback") |
| + |
| + if is_torch_npu_available(): |
| + self.skipTest( |
| + reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." |
| + ) |
| + |
| + logger = logging.get_logger("transformers.modeling_utils") |
| + with LoggingLevel(logging.WARNING): |
| + with CaptureLogger(logger) as cl: |
| + _ = AutoModel.from_pretrained( |
| + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" |
| + ) |
| + self.assertTrue( |
| + "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!" |
| + in cl.out |
| + ) |
| + |
| + def test_not_available_kernels(self): |
| + # Explicitly requesting a hub kernel must raise an ImportError when `kernels` is not available. |
| + if is_kernels_available(): |
| + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`") |
| + with self.assertRaises(ImportError) as cm: |
| + _ = AutoModel.from_pretrained( |
| + "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn" |
| + ) |
| + |
| + self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception)) |
| + |
| |
| @require_torch |
| class TestTensorSharing(TestCasePlus): |
|
|