File size: 8,512 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 50a200ae76bc..17df24664edb 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1978,8 +1978,6 @@ def _from_config(cls, config, **kwargs):
         if isinstance(torch_dtype, str):
             torch_dtype = getattr(torch, torch_dtype)
 
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
-
         # override default dtype if needed
         dtype_orig = None
         if torch_dtype is not None:
@@ -1998,7 +1996,6 @@ def _from_config(cls, config, **kwargs):
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
                 config,
-                use_flash_attention_2=use_flash_attention_2,
                 check_device_map=False,
                 torch_dtype=torch_dtype,
             )
@@ -2024,7 +2021,6 @@ def _from_config(cls, config, **kwargs):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -2032,21 +2028,14 @@ def _autoset_attn_implementation(
         """
         Automatically checks and dispatches to a default attention implementation. In order of priority:
             1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
-            2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
-            3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
-            4. The default model's implementation otherwise (`LlamaAttention` for example) .
+            2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+            3. The default model's implementation otherwise (`LlamaAttention` for example).
         """
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
         requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
-            if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
-                raise ValueError(
-                    f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
-                    ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
-                )
-
             if isinstance(config._attn_implementation, str) and re.match(
                 r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
             ):
@@ -2111,12 +2100,6 @@ def _autoset_attn_implementation(
             if sub_config is not None:
                 sub_config._attn_implementation_internal = curr_attn_implementation
 
-        if use_flash_attention_2:
-            logger.warning_once(
-                'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
-            )
-            config._attn_implementation = "flash_attention_2"
-
         if config._attn_implementation == "flash_attention_2":
             cls._check_and_enable_flash_attn_2(
                 config,
@@ -2128,10 +2111,10 @@ def _autoset_attn_implementation(
         elif requested_attn_implementation == "flex_attention":
             config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
         elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
-            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+            # flash_attention_2 takes priority over SDPA, hence SDPA is handled in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
-                hard_check_only=False if requested_attn_implementation is None else True,
+                hard_check_only=requested_attn_implementation is not None,
             )
 
             if (
@@ -4022,7 +4005,6 @@ def from_pretrained(
         variant = kwargs.pop("variant", None)
         adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
@@ -4403,7 +4385,6 @@ def from_pretrained(
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
                 config,
-                use_flash_attention_2=use_flash_attention_2,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
             )
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index accc07f6c35b..5c5ecb669630 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -643,7 +643,6 @@ def init_weight(module: nn.Module, std: float):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -666,8 +665,7 @@ def _autoset_attn_implementation(
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
             device_map=device_map,
             check_device_map=check_device_map,
         )
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index f687324c5af5..7ef052a93cf2 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -845,7 +845,6 @@ def init_weight(module: nn.Module, std: float):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -868,8 +867,7 @@ def _autoset_attn_implementation(
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
             device_map=device_map,
             check_device_map=check_device_map,
         )
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index c738fbf76d1a..b03231396019 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -488,7 +488,7 @@ def test_use_flash_attention_2_true(self):
                 model.save_pretrained(tmp_dir)
 
                 new_model = DiffLlamaForCausalLM.from_pretrained(
-                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+                    tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
                 ).to("cuda")
 
                 self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")