File size: 8,512 Bytes
dfefe0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 50a200ae76bc..17df24664edb 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1978,8 +1978,6 @@ def _from_config(cls, config, **kwargs):
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)
- use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
-
# override default dtype if needed
dtype_orig = None
if torch_dtype is not None:
@@ -1998,7 +1996,6 @@ def _from_config(cls, config, **kwargs):
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
- use_flash_attention_2=use_flash_attention_2,
check_device_map=False,
torch_dtype=torch_dtype,
)
@@ -2024,7 +2021,6 @@ def _from_config(cls, config, **kwargs):
def _autoset_attn_implementation(
cls,
config,
- use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
@@ -2032,21 +2028,14 @@ def _autoset_attn_implementation(
"""
Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
- 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
- 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
- 4. The default model's implementation otherwise (`LlamaAttention` for example) .
+ 2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+ 3. The default model's implementation otherwise (`LlamaAttention` for example).
"""
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
requested_attn_implementation = None
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
- if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
- raise ValueError(
- f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
- ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
- )
-
if isinstance(config._attn_implementation, str) and re.match(
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
):
@@ -2111,12 +2100,6 @@ def _autoset_attn_implementation(
if sub_config is not None:
sub_config._attn_implementation_internal = curr_attn_implementation
- if use_flash_attention_2:
- logger.warning_once(
- 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
- )
- config._attn_implementation = "flash_attention_2"
-
if config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
@@ -2128,10 +2111,10 @@ def _autoset_attn_implementation(
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
- # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+ # flash_attention_2 takes priority over SDPA, hence SDPA is treated in this elif.
config = cls._check_and_enable_sdpa(
config,
- hard_check_only=False if requested_attn_implementation is None else True,
+ hard_check_only=requested_attn_implementation is not None,
)
if (
@@ -4022,7 +4005,6 @@ def from_pretrained(
variant = kwargs.pop("variant", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
- use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
@@ -4403,7 +4385,6 @@ def from_pretrained(
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
- use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
)
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index accc07f6c35b..5c5ecb669630 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -643,7 +643,6 @@ def init_weight(module: nn.Module, std: float):
def _autoset_attn_implementation(
cls,
config,
- use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
@@ -666,8 +665,7 @@ def _autoset_attn_implementation(
config._attn_implementation_internal = None
return super()._autoset_attn_implementation(
config,
- use_flash_attention_2=use_flash_attention_2,
- torch_dtype=torch.float16,
+ torch_dtype=torch_dtype,
device_map=device_map,
check_device_map=check_device_map,
)
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index f687324c5af5..7ef052a93cf2 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -845,7 +845,6 @@ def init_weight(module: nn.Module, std: float):
def _autoset_attn_implementation(
cls,
config,
- use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
@@ -868,8 +867,7 @@ def _autoset_attn_implementation(
config._attn_implementation_internal = None
return super()._autoset_attn_implementation(
config,
- use_flash_attention_2=use_flash_attention_2,
- torch_dtype=torch.float16,
+ torch_dtype=torch_dtype,
device_map=device_map,
check_device_map=check_device_map,
)
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index c738fbf76d1a..b03231396019 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -488,7 +488,7 @@ def test_use_flash_attention_2_true(self):
model.save_pretrained(tmp_dir)
new_model = DiffLlamaForCausalLM.from_pretrained(
- tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+ tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
|