harness / diffs /38023.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index 28c6d249c723..d9e101ec1039 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,10 +31,14 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, logging
+from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_esm import EsmConfig
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
@@ -111,8 +116,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
return (
- apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
- apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
)
@@ -244,6 +249,8 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
class EsmSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
+ self.config = config
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
@@ -392,10 +399,128 @@ def forward(self, hidden_states, input_tensor):
return hidden_states
+class EsmFlashAttention2(EsmSelfAttention):
+ """
+ ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__(config, position_embedding_type=position_embedding_type)
+
+ # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self.dropout_prob = config.attention_probs_dropout_prob
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ # Flash attention doesn't support output_attentions, head_mask, or cross attention
+ if output_attentions or head_mask is not None or encoder_hidden_states is not None:
+ logger.warning_once(
+ "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
+ "Falling back to the manual attention implementation. This warning can be removed using "
+ 'the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ if past_key_value is not None:
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended to not cast the LayerNorms
+ # to fp32.
+ input_dtype = query_layer.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.query.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_layer = query_layer.to(target_dtype)
+ key_layer = key_layer.to(target_dtype)
+ value_layer = value_layer.to(target_dtype)
+
+ # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+ # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+ # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+ # ESM code and fix rotary embeddings.
+ query_layer = query_layer * self.attention_head_size**-0.5
+
+ if self.position_embedding_type == "rotary":
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+ elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")
+
+ # It would likely be faster to change self.transpose_for_scores to output the correct
+ # dimensions for flash_attention_2, but that would also mean changing the rotary embedding
+ # functions. Here we just permute the dimensions to match the expected input.
+ attn_output = _flash_attention_forward(
+ query_layer.permute(0, 2, 1, 3),
+ key_layer.permute(0, 2, 1, 3),
+ value_layer.permute(0, 2, 1, 3),
+ attention_mask,
+ query_length=q_len,
+ is_causal=self.is_decoder,
+ softmax_scale=1.0,
+ dropout=self.dropout_prob if self.training else 0.0,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ outputs = (attn_output, None)
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+
+ return outputs
+
+
+ESM_ATTENTION_CLASSES = {
+ "eager": EsmSelfAttention,
+ "flash_attention_2": EsmFlashAttention2,
+}
+
+
class EsmAttention(nn.Module):
def __init__(self, config):
super().__init__()
- self.self = EsmSelfAttention(config)
+ self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config)
self.output = EsmSelfOutput(config)
self.pruned_heads = set()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -672,6 +797,7 @@ class EsmPreTrainedModel(PreTrainedModel):
base_model_prefix = "esm"
supports_gradient_checkpointing = True
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
+ _supports_flash_attn_2 = True
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
def _init_weights(self, module):
@@ -805,9 +931,13 @@ def forward(
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
- # ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+ if self.config._attn_implementation == "flash_attention_2":
+ extended_attention_mask = attention_mask
+
+ else:
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py
index c47f87b408a7..203aa9a69a39 100644
--- a/src/transformers/models/esm/modeling_esmfold.py
+++ b/src/transformers/models/esm/modeling_esmfold.py
@@ -1980,6 +1980,7 @@ def distogram(coords, min_bin, max_bin, num_bins):
)
class EsmForProteinFolding(EsmPreTrainedModel):
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
+ _supports_flash_attn_2 = False
def __init__(self, config):
super().__init__(config)
@@ -2050,6 +2051,7 @@ def forward(
position_ids: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
num_recycles: Optional[int] = None,
+ output_hidden_states: Optional[bool] = False,
) -> EsmForProteinFoldingOutput:
r"""
masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py
index 74f4c277d092..18887bb5927c 100644
--- a/tests/models/esm/test_modeling_esm.py
+++ b/tests/models/esm/test_modeling_esm.py
@@ -13,10 +13,22 @@
# limitations under the License.
"""Testing suite for the PyTorch ESM model."""
+import tempfile
import unittest
+import pytest
+
from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import (
+ TestCasePlus,
+ is_flaky,
+ require_bitsandbytes,
+ require_flash_attn,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -59,6 +71,7 @@ def __init__(
num_labels=3,
num_choices=4,
scope=None,
+ position_embedding_type="rotary",
):
self.parent = parent
self.batch_size = batch_size
@@ -82,6 +95,7 @@ def __init__(
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
+ self.position_embedding_type = position_embedding_type
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -116,6 +130,7 @@ def get_config(self):
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
+ position_embedding_type=self.position_embedding_type,
)
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
@@ -296,6 +311,39 @@ def test_resize_embeddings_untied(self):
def test_resize_tokens_embeddings(self):
pass
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_2_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(reason="Model does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ dummy_input = dummy_input.to(torch_device)
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ torch.testing.assert_close(logits_fa, logits, atol=1e-2, rtol=1e-3)
+
@slow
@require_torch