diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..71adbb9188e7 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None: if not is_kernel(attn_implementation): return if not _kernels_available: - raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") + raise ImportError( + "`kernels` is either not installed or uses an incompatible version. " + "Please install the latest version with `pip install -U kernels`." + ) # Need to be imported here as otherwise we have a circular import in `modeling_utils` from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a64085c4e931..e2a623c6bbe4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2675,35 +2675,39 @@ def _check_and_adjust_attn_implementation( None to sdpa (to potentially eager). """ applicable_attn_implementation = attn_implementation + # If FA not installed, do not fail but use kernels instead if ( - applicable_attn_implementation == "flash_attention_2" + attn_implementation is not None + and attn_implementation.startswith("flash_attention") and self._supports_flash_attn - and not is_flash_attn_2_available() + and not (is_flash_attn_2_available() or is_flash_attn_3_available()) and is_kernels_available() ): - applicable_attn_implementation = "kernels-community/flash-attn" + if attn_implementation.endswith("2"): + applicable_attn_implementation = "kernels-community/flash-attn" + else: + applicable_attn_implementation = "kernels-community/vllm-flash-attn3" + if is_kernel(applicable_attn_implementation): try: load_and_register_kernel(applicable_attn_implementation) # log that we used kernel fallback if successful - if attn_implementation == "flash_attention_2": + if attn_implementation.startswith("flash_attention"): logger.warning_once( - "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` " - "library instead!" + f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` " + "from the `kernels` library instead!" ) except Exception as e: - if attn_implementation == "flash_attention_2": - self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception - logger.warning_once( - f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " - f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." - ) - try: - self._sdpa_can_dispatch(is_init_check) - applicable_attn_implementation = "sdpa" - except (ValueError, ImportError): - applicable_attn_implementation = "eager" + # raise the proper exception for requested flash attention + if attn_implementation.startswith("flash_attention"): + if attn_implementation.endswith("2"): + self._flash_attn_2_can_dispatch() + else: + self._flash_attn_3_can_dispatch() + + # error properly out if a kernel was specifically requested + raise e else: applicable_attn_implementation = self.get_correct_attn_implementation( applicable_attn_implementation, is_init_check diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index fc2bbb60c452..1674e9dabfc2 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -88,6 +88,7 @@ from transformers.utils.import_utils import ( is_flash_attn_2_available, is_flash_attn_3_available, + is_kernels_available, is_torch_npu_available, ) @@ -2849,6 +2850,9 @@ def test_not_available_flash(self): reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." ) + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`") + with self.assertRaises(ImportError) as cm: _ = AutoModel.from_pretrained( "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" @@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self): reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." ) + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`") + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") with self.assertRaises(ImportError) as cm: @@ -2875,6 +2882,41 @@ def test_not_available_flash_with_config(self): self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_kernels_fallback(self): + if not is_kernels_available(): + self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`") + + if is_flash_attn_2_available(): + self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback") + + if is_torch_npu_available(): + self.skipTest( + reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." + ) + + logger = logging.get_logger("transformers.modeling_utils") + with LoggingLevel(logging.WARNING): + with CaptureLogger(logger) as cl: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" + ) + + self.assertTrue( + "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!" + in cl.out + ) + + def test_not_available_kernels(self): + if is_kernels_available(): + self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn" + ) + + self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception)) + @require_torch class TestTensorSharing(TestCasePlus):