harness / diffs /41010.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 5be21e2f9a51..71adbb9188e7 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -152,7 +152,10 @@ def load_and_register_kernel(attn_implementation: str) -> None:
if not is_kernel(attn_implementation):
return
if not _kernels_available:
- raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")
+ raise ImportError(
+ "`kernels` is either not installed or uses an incompatible version. "
+ "Please install the latest version with `pip install -U kernels`."
+ )
# Need to be imported here as otherwise we have a circular import in `modeling_utils`
from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a64085c4e931..e2a623c6bbe4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2675,35 +2675,39 @@ def _check_and_adjust_attn_implementation(
None to sdpa (to potentially eager).
"""
applicable_attn_implementation = attn_implementation
+
# If FA not installed, do not fail but use kernels instead
if (
- applicable_attn_implementation == "flash_attention_2"
+ attn_implementation is not None
+ and attn_implementation.startswith("flash_attention")
and self._supports_flash_attn
- and not is_flash_attn_2_available()
+ and not (is_flash_attn_2_available() or is_flash_attn_3_available())
and is_kernels_available()
):
- applicable_attn_implementation = "kernels-community/flash-attn"
+ if attn_implementation.endswith("2"):
+ applicable_attn_implementation = "kernels-community/flash-attn"
+ else:
+ applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+
if is_kernel(applicable_attn_implementation):
try:
load_and_register_kernel(applicable_attn_implementation)
# log that we used kernel fallback if successful
- if attn_implementation == "flash_attention_2":
+ if attn_implementation.startswith("flash_attention"):
logger.warning_once(
- "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
- "library instead!"
+ f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
+ "from the `kernels` library instead!"
)
except Exception as e:
- if attn_implementation == "flash_attention_2":
- self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception
- logger.warning_once(
- f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
- f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
- )
- try:
- self._sdpa_can_dispatch(is_init_check)
- applicable_attn_implementation = "sdpa"
- except (ValueError, ImportError):
- applicable_attn_implementation = "eager"
+ # raise the proper exception for requested flash attention
+ if attn_implementation.startswith("flash_attention"):
+ if attn_implementation.endswith("2"):
+ self._flash_attn_2_can_dispatch()
+ else:
+ self._flash_attn_3_can_dispatch()
+
+ # error properly out if a kernel was specifically requested
+ raise e
else:
applicable_attn_implementation = self.get_correct_attn_implementation(
applicable_attn_implementation, is_init_check
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index fc2bbb60c452..1674e9dabfc2 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -88,6 +88,7 @@
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
+ is_kernels_available,
is_torch_npu_available,
)
@@ -2849,6 +2850,9 @@ def test_not_available_flash(self):
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
)
+ if is_kernels_available():
+ self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`")
+
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
@@ -2864,6 +2868,9 @@ def test_not_available_flash_with_config(self):
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
)
+ if is_kernels_available():
+ self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`")
+
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
with self.assertRaises(ImportError) as cm:
@@ -2875,6 +2882,41 @@ def test_not_available_flash_with_config(self):
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+ def test_kernels_fallback(self):
+ if not is_kernels_available():
+ self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")
+
+ if is_flash_attn_2_available():
+ self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")
+
+ if is_torch_npu_available():
+ self.skipTest(
+ reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
+ )
+
+ logger = logging.get_logger("transformers.modeling_utils")
+ with LoggingLevel(logging.WARNING):
+ with CaptureLogger(logger) as cl:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
+ )
+
+ self.assertTrue(
+ "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` library instead!"
+ in cl.out
+ )
+
+ def test_not_available_kernels(self):
+ if is_kernels_available():
+ self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn"
+ )
+
+ self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
+
@require_torch
class TestTensorSharing(TestCasePlus):