| |
| |
| |
| |
| @@ -1,5 +1,6 @@ |
| # coding=utf-8 |
| # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. |
| +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| @@ -30,10 +31,14 @@ |
| TokenClassifierOutput, |
| ) |
| from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer |
| -from ...utils import auto_docstring, logging |
| +from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging |
| from .configuration_esm import EsmConfig |
| |
| |
| +if is_flash_attn_2_available(): |
| + from ...modeling_flash_attention_utils import _flash_attention_forward |
| + |
| + |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -111,8 +116,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch |
| self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) |
| |
| return ( |
| - apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), |
| - apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), |
| + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype), |
| + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype), |
| ) |
| |
| |
| @@ -244,6 +249,8 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): |
| class EsmSelfAttention(nn.Module): |
| def __init__(self, config, position_embedding_type=None): |
| super().__init__() |
| + self.config = config |
| + |
| if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): |
| raise ValueError( |
| f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " |
| @@ -392,10 +399,128 @@ def forward(self, hidden_states, input_tensor): |
| return hidden_states |
| |
| |
class EsmFlashAttention2(EsmSelfAttention):
    """
    ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, which needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, config, position_embedding_type=None):
        """
        Build the parent attention module and record the flash-attention specific settings.

        Args:
            config: Model configuration; `attention_probs_dropout_prob` is read from it here.
            position_embedding_type: Forwarded unchanged to `EsmSelfAttention.__init__`.
        """
        super().__init__(config, position_embedding_type=position_embedding_type)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        # Raw dropout probability; it is handed to the flash attention call in training mode only.
        self.dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Compute self-attention with the flash attention kernel.

        Requests that the kernel cannot serve (`output_attentions`, a `head_mask`, or cross-attention through
        `encoder_hidden_states`) are delegated to `EsmSelfAttention.forward` after a one-time warning.

        Args:
            hidden_states: 3-D input; its first two dimensions are used as `(batch_size, query_length)`.
            attention_mask: Mask handed unchanged to `_flash_attention_forward`.
            head_mask: Only supported by the fallback path.
            encoder_hidden_states: Only supported by the fallback path.
            encoder_attention_mask: Only forwarded to the fallback path.
            past_key_value: Optional `(key, value)` pair from earlier steps; it is concatenated in front of the
                freshly projected keys/values along dimension 2.
            output_attentions: Only supported by the fallback path.

        Returns:
            `(attn_output, None)`, where `attn_output` is reshaped to `(batch_size, query_length, -1)` and the
            second entry is `None` because this path never materialises attention probabilities. When
            `self.is_decoder` is set, the updated `(key, value)` cache is appended as a third entry.

        Raises:
            ValueError: If `position_embedding_type` is `"relative_key"` or `"relative_key_query"`.
        """
        # Flash attention doesn't support output_attentions or cross attention
        if output_attentions or head_mask is not None or encoder_hidden_states is not None:
            logger.warning_once(
                "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
                "Falling back to the manual attention implementation. This warning can be removed using "
                'the argument `attn_implementation="eager"` when loading the model.'
            )
            # NOTE(review): with `flash_attention_2` the model appears to hand this layer the raw attention
            # mask rather than the extended one - confirm the eager implementation accepts that format.
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        if past_key_value is not None:
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        if self.is_decoder:
            # Hand back the *updated* cache. Returning the incoming `past_key_value` untouched would mean the
            # cache never grows. The states are captured before rotary embeddings are applied because, on the
            # next call, the rotation is applied again to the whole concatenated key tensor.
            past_key_value = (key_layer, value_layer)

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
        # to fp32.
        input_dtype = query_layer.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.query.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_layer = query_layer.to(target_dtype)
            key_layer = key_layer.to(target_dtype)
            value_layer = value_layer.to(target_dtype)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # ESM code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
        elif self.position_embedding_type in ("relative_key", "relative_key_query"):
            raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")

        # It would likely be faster to change self.transpose_for_scores to output the correct
        # dimensions for flash_attention_2, but that would also mean changing the rotary embedding
        # functions. Here we just permute the dimensions to match the expected input.
        attn_output = _flash_attention_forward(
            query_layer.permute(0, 2, 1, 3),
            key_layer.permute(0, 2, 1, 3),
            value_layer.permute(0, 2, 1, 3),
            attention_mask,
            query_length=q_len,
            is_causal=self.is_decoder,
            # The query has already been scaled above, so the kernel must not scale again.
            softmax_scale=1.0,
            dropout=self.dropout_prob if self.training else 0.0,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)

        outputs = (attn_output, None)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)

        return outputs
| + |
| + |
# Lookup table from an attention implementation name (the value of `config._attn_implementation`)
# to the self-attention class that provides it.
ESM_ATTENTION_CLASSES = dict(
    eager=EsmSelfAttention,
    flash_attention_2=EsmFlashAttention2,
)
| + |
| + |
| class EsmAttention(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| - self.self = EsmSelfAttention(config) |
| + self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config) |
| self.output = EsmSelfOutput(config) |
| self.pruned_heads = set() |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| @@ -672,6 +797,7 @@ class EsmPreTrainedModel(PreTrainedModel): |
| base_model_prefix = "esm" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] |
| + _supports_flash_attn_2 = True |
| |
| # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead |
| def _init_weights(self, module): |
| @@ -805,9 +931,13 @@ def forward( |
| if attention_mask is None: |
| attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config._attn_implementation == "flash_attention_2": |
| + extended_attention_mask = attention_mask |
| + |
| + else: |
| + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| + # ourselves in which case we just need to make it broadcastable to all heads. |
| + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| |
| |
| |
| |
| @@ -1980,6 +1980,7 @@ def distogram(coords, min_bin, max_bin, num_bins): |
| ) |
| class EsmForProteinFolding(EsmPreTrainedModel): |
| _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] |
| + _supports_flash_attn_2 = False |
| |
| def __init__(self, config): |
| super().__init__(config) |
| @@ -2050,6 +2051,7 @@ def forward( |
| position_ids: Optional[torch.Tensor] = None, |
| masking_pattern: Optional[torch.Tensor] = None, |
| num_recycles: Optional[int] = None, |
| + output_hidden_states: Optional[bool] = False, |
| ) -> EsmForProteinFoldingOutput: |
| r""" |
| masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| |
| |
| |
| |
| @@ -13,10 +13,22 @@ |
| # limitations under the License. |
| """Testing suite for the PyTorch ESM model.""" |
| |
| +import tempfile |
| import unittest |
| |
| +import pytest |
| + |
| from transformers import EsmConfig, is_torch_available |
| -from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device |
| +from transformers.testing_utils import ( |
| + TestCasePlus, |
| + is_flaky, |
| + require_bitsandbytes, |
| + require_flash_attn, |
| + require_torch, |
| + require_torch_gpu, |
| + slow, |
| + torch_device, |
| +) |
| |
| from ...test_configuration_common import ConfigTester |
| from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask |
| @@ -59,6 +71,7 @@ def __init__( |
| num_labels=3, |
| num_choices=4, |
| scope=None, |
| + position_embedding_type="rotary", |
| ): |
| self.parent = parent |
| self.batch_size = batch_size |
| @@ -82,6 +95,7 @@ def __init__( |
| self.num_labels = num_labels |
| self.num_choices = num_choices |
| self.scope = scope |
| + self.position_embedding_type = position_embedding_type |
| |
| def prepare_config_and_inputs(self): |
| input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) |
| @@ -116,6 +130,7 @@ def get_config(self): |
| max_position_embeddings=self.max_position_embeddings, |
| type_vocab_size=self.type_vocab_size, |
| initializer_range=self.initializer_range, |
| + position_embedding_type=self.position_embedding_type, |
| ) |
| |
| def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): |
| @@ -296,6 +311,39 @@ def test_resize_embeddings_untied(self): |
| def test_resize_tokens_embeddings(self): |
| pass |
| |
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
    """Check that the flash-attention and eager models produce close final hidden states in fp16."""
    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(reason="Model does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            # Round-trip through disk so both variants start from exactly the same weights.
            model.save_pretrained(checkpoint_dir)

            # Flash attention variant is loaded first, then the eager reference.
            reloaded = {}
            for attn_implementation in ("flash_attention_2", "eager"):
                reloaded[attn_implementation] = model_class.from_pretrained(
                    checkpoint_dir, torch_dtype=torch.float16, attn_implementation=attn_implementation
                )
                reloaded[attn_implementation].to(torch_device)

            main_input = inputs_dict[model_class.main_input_name].to(torch_device)

            # Eager forward runs before the flash attention forward.
            eager_outputs = reloaded["eager"](main_input, output_hidden_states=True)
            flash_outputs = reloaded["flash_attention_2"](main_input, output_hidden_states=True)

            eager_last_hidden = eager_outputs.hidden_states[-1]
            flash_last_hidden = flash_outputs.hidden_states[-1]

            torch.testing.assert_close(flash_last_hidden, eager_last_hidden, atol=1e-2, rtol=1e-3)
| + |
| |
| @slow |
| @require_torch |
|
|