| |
| |
| |
| |
| @@ -1978,8 +1978,6 @@ def _from_config(cls, config, **kwargs): |
| if isinstance(torch_dtype, str): |
| torch_dtype = getattr(torch, torch_dtype) |
| |
| - use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) |
| - |
| # override default dtype if needed |
| dtype_orig = None |
| if torch_dtype is not None: |
| @@ -1998,7 +1996,6 @@ def _from_config(cls, config, **kwargs): |
| if not getattr(config, "_attn_implementation_autoset", False): |
| config = cls._autoset_attn_implementation( |
| config, |
| - use_flash_attention_2=use_flash_attention_2, |
| check_device_map=False, |
| torch_dtype=torch_dtype, |
| ) |
| @@ -2024,7 +2021,6 @@ def _from_config(cls, config, **kwargs): |
| def _autoset_attn_implementation( |
| cls, |
| config, |
| - use_flash_attention_2: bool = False, |
| torch_dtype: Optional[torch.dtype] = None, |
| device_map: Optional[Union[str, Dict[str, int]]] = None, |
| check_device_map: bool = True, |
| @@ -2032,21 +2028,14 @@ def _autoset_attn_implementation( |
| """ |
| Automatically checks and dispatches to a default attention implementation. In order of priority: |
| 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). |
| - 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) |
| - 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) |
| - 4. The default model's implementation otherwise (`LlamaAttention` for example) . |
| + 2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) |
| + 3. The default model's implementation otherwise (`LlamaAttention` for example) . |
| """ |
| # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user. |
| # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). |
| # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) |
| requested_attn_implementation = None |
| if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: |
| - if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: |
| - raise ValueError( |
| - f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' |
| - ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' |
| - ) |
| - |
| if isinstance(config._attn_implementation, str) and re.match( |
| r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation |
| ): |
| @@ -2111,12 +2100,6 @@ def _autoset_attn_implementation( |
| if sub_config is not None: |
| sub_config._attn_implementation_internal = curr_attn_implementation |
| |
| - if use_flash_attention_2: |
| - logger.warning_once( |
| - 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' |
| - ) |
| - config._attn_implementation = "flash_attention_2" |
| - |
| if config._attn_implementation == "flash_attention_2": |
| cls._check_and_enable_flash_attn_2( |
| config, |
| @@ -2128,10 +2111,10 @@ def _autoset_attn_implementation( |
| elif requested_attn_implementation == "flex_attention": |
| config = cls._check_and_enable_flex_attn(config, hard_check_only=True) |
| elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): |
| - # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. |
| + # flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. |
| config = cls._check_and_enable_sdpa( |
| config, |
| - hard_check_only=False if requested_attn_implementation is None else True, |
| + hard_check_only=requested_attn_implementation is not None, |
| ) |
| |
| if ( |
| @@ -4022,7 +4005,6 @@ def from_pretrained( |
| variant = kwargs.pop("variant", None) |
| adapter_kwargs = kwargs.pop("adapter_kwargs", {}) |
| adapter_name = kwargs.pop("adapter_name", "default") |
| - use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) |
| generation_config = kwargs.pop("generation_config", None) |
| gguf_file = kwargs.pop("gguf_file", None) |
| tp_plan = kwargs.pop("tp_plan", None) |
| @@ -4403,7 +4385,6 @@ def from_pretrained( |
| if not getattr(config, "_attn_implementation_autoset", False): |
| config = cls._autoset_attn_implementation( |
| config, |
| - use_flash_attention_2=use_flash_attention_2, |
| torch_dtype=torch_dtype, |
| device_map=device_map, |
| ) |
| |
| |
| |
| |
| @@ -643,7 +643,6 @@ def init_weight(module: nn.Module, std: float): |
| def _autoset_attn_implementation( |
| cls, |
| config, |
| - use_flash_attention_2: bool = False, |
| torch_dtype: Optional[torch.dtype] = None, |
| device_map: Optional[Union[str, Dict[str, int]]] = None, |
| check_device_map: bool = True, |
| @@ -666,8 +665,7 @@ def _autoset_attn_implementation( |
| config._attn_implementation_internal = None |
| return super()._autoset_attn_implementation( |
| config, |
| - use_flash_attention_2=use_flash_attention_2, |
| - torch_dtype=torch.float16, |
| + torch_dtype=torch_dtype, |
| device_map=device_map, |
| check_device_map=check_device_map, |
| ) |
| |
| |
| |
| |
| @@ -845,7 +845,6 @@ def init_weight(module: nn.Module, std: float): |
| def _autoset_attn_implementation( |
| cls, |
| config, |
| - use_flash_attention_2: bool = False, |
| torch_dtype: Optional[torch.dtype] = None, |
| device_map: Optional[Union[str, Dict[str, int]]] = None, |
| check_device_map: bool = True, |
| @@ -868,8 +867,7 @@ def _autoset_attn_implementation( |
| config._attn_implementation_internal = None |
| return super()._autoset_attn_implementation( |
| config, |
| - use_flash_attention_2=use_flash_attention_2, |
| - torch_dtype=torch.float16, |
| + torch_dtype=torch_dtype, |
| device_map=device_map, |
| check_device_map=check_device_map, |
| ) |
| |
| |
| |
| |
| @@ -488,7 +488,7 @@ def test_use_flash_attention_2_true(self): |
| model.save_pretrained(tmp_dir) |
| |
| new_model = DiffLlamaForCausalLM.from_pretrained( |
| - tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 |
| + tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16 |
| ).to("cuda") |
| |
| self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") |
|
|