diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 28c6d249c723..d9e101ec1039 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,10 +31,14 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging from .configuration_esm import EsmConfig +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + logger = logging.get_logger(__name__) @@ -111,8 +116,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) return ( - apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), - apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype), ) @@ -244,6 +249,8 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): class EsmSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " @@ -392,10 +399,128 @@ def forward(self, hidden_states, input_tensor): return hidden_states +class EsmFlashAttention2(EsmSelfAttention): + """ + ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # Flash attention doesn't support output_attentions or cross attention + if output_attentions or head_mask is not None or encoder_hidden_states is not None: + logger.warning_once( + "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. " + "Falling back to the manual attention implementation. This warning can be removed using " + 'the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + if past_key_value is not None: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + input_dtype = query_layer.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings") + + # It would likely be faster to change self.transpose_for_scores to output the correct + # dimensions for flash_attention_2, but that would also mean changing the rotary embedding + # functions. Here we just permute the dimensions to match the expected input. + attn_output = _flash_attention_forward( + query_layer.permute(0, 2, 1, 3), + key_layer.permute(0, 2, 1, 3), + value_layer.permute(0, 2, 1, 3), + attention_mask, + query_length=q_len, + is_causal=self.is_decoder, + softmax_scale=1.0, + dropout=self.dropout_prob if self.training else 0.0, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + outputs = (attn_output, None) + if self.is_decoder: + outputs = outputs + (past_key_value,) + + return outputs + + +ESM_ATTENTION_CLASSES = { + "eager": EsmSelfAttention, + "flash_attention_2": EsmFlashAttention2, +} + + class EsmAttention(nn.Module): def __init__(self, config): super().__init__() - self.self = EsmSelfAttention(config) + self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config) self.output = EsmSelfOutput(config) self.pruned_heads = set() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -672,6 +797,7 @@ class EsmPreTrainedModel(PreTrainedModel): base_model_prefix = "esm" supports_gradient_checkpointing = True _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] + _supports_flash_attn_2 = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead def _init_weights(self, module): @@ -805,9 +931,13 @@ def forward( if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config._attn_implementation == "flash_attention_2": + extended_attention_mask = attention_mask + + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index c47f87b408a7..203aa9a69a39 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -1980,6 +1980,7 @@ def distogram(coords, min_bin, max_bin, num_bins): ) class EsmForProteinFolding(EsmPreTrainedModel): _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] + _supports_flash_attn_2 = False def __init__(self, config): super().__init__(config) @@ -2050,6 +2051,7 @@ def forward( position_ids: Optional[torch.Tensor] = None, masking_pattern: Optional[torch.Tensor] = None, num_recycles: Optional[int] = None, + output_hidden_states: Optional[bool] = False, ) -> EsmForProteinFoldingOutput: r""" masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 74f4c277d092..18887bb5927c 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -13,10 +13,22 @@ # limitations under the License. """Testing suite for the PyTorch ESM model.""" +import tempfile import unittest +import pytest + from transformers import EsmConfig, is_torch_available -from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device +from transformers.testing_utils import ( + TestCasePlus, + is_flaky, + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -59,6 +71,7 @@ def __init__( num_labels=3, num_choices=4, scope=None, + position_embedding_type="rotary", ): self.parent = parent self.batch_size = batch_size @@ -82,6 +95,7 @@ def __init__( self.num_labels = num_labels self.num_choices = num_choices self.scope = scope + self.position_embedding_type = position_embedding_type def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -116,6 +130,7 @@ def get_config(self): max_position_embeddings=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, + position_embedding_type=self.position_embedding_type, ) def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): @@ -296,6 +311,39 @@ def test_resize_embeddings_untied(self): def test_resize_tokens_embeddings(self): pass + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @is_flaky() + @slow + def test_flash_attn_2_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(reason="Model does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + torch.testing.assert_close(logits_fa, logits, atol=1e-2, rtol=1e-3) + @slow @require_torch