ProtT3_model / model /esm_flash_attention.py
yuccaaa's picture
Add files using upload-large-folder tool
4d12519 verified
import torch
import logging
import torch.nn as nn
import transformers
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from einops import rearrange
from typing import List, Optional, Tuple, Union
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """FlashAttention replacement for the ESM self-attention ``forward``.

    Only the encoder-style, rotary-embedding self-attention path is
    supported; every other configuration is rejected by the assertions below.

    Args:
        hidden_states: Input of shape ``[bsz, t, hidden]``.
        attention_mask: Required 2-D key-padding mask of shape ``[bsz, t]``
            whose non-zero entries mark real tokens. It is handed directly
            to ``unpad_input``, so it must NOT be the additive 4-D mask used
            by the eager attention path.
        head_mask: Must be ``None``.
        encoder_hidden_states: Must be ``None`` (no cross-attention).
        encoder_attention_mask: Ignored.
        past_key_value: Must be ``None`` (no key/value cache).
        output_attentions: Must be ``False``; attention probabilities are
            never materialised on this path.

    Returns:
        A 1-tuple holding the context tensor of shape ``[bsz, t, nh * hd]``,
        in the dtype produced by the query projection.
    """
    assert encoder_hidden_states is None, "Cross-attention is not supported for ESM"
    assert past_key_value is None, "Past key value is not supported for ESM"
    assert self.is_decoder is False, "Decoder is not supported for ESM"
    assert self.position_embedding_type == "rotary", "Rotary embeddings are required for ESM"
    assert head_mask is None, "Head mask is not supported for ESM"
    assert output_attentions is False, "Output attentions is not supported for ESM"
    assert attention_mask is not None

    mixed_query_layer = self.query(hidden_states)
    # The kernel below runs in bfloat16; remember the projection dtype so the
    # result can be handed back in the dtype the following layers expect.
    output_dtype = mixed_query_layer.dtype

    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    query_layer = self.transpose_for_scores(mixed_query_layer)

    # ESM scales the query (not the logits) by 1/sqrt(head_dim), and does so
    # before the rotary embedding; hence softmax_scale=1 in the kernel call.
    query_layer = query_layer * self.attention_head_size**-0.5
    if self.position_embedding_type == "rotary":
        query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    # query/key/value are each [bsz, nh, t, hd].
    qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)  # [bsz, nh, 3, t, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, t, 3, nh, hd]

    key_padding_mask = attention_mask
    bsz, q_len, _ = hidden_states.size()
    nheads = qkv.shape[-2]

    # Drop padded positions so the varlen kernel only sees real tokens.
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # Newer flash-attn releases append a fifth return value to unpad_input;
    # slicing keeps this working with both the 4- and 5-tuple variants.
    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)[:4]
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    x_unpad = x_unpad.to(torch.bfloat16)

    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_q_lens,
        max_s,
        self.dropout.p if self.training else 0.0,
        softmax_scale=1,
        causal=False,
    )

    # Scatter the packed tokens back into the padded [bsz, t, nh * hd] layout.
    outputs = pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len)
    # No-op when the projections already ran in bfloat16 (e.g. under autocast).
    outputs = outputs.to(output_dtype)
    return (outputs,)
def get_extended_attention_mask(
    self,
    attention_mask: torch.Tensor,
    input_shape: Tuple[int, ...],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Return ``attention_mask`` unchanged.

    This stands in for the model-level mask expansion when the
    FlashAttention ``forward`` is installed: that path feeds the mask to
    ``unpad_input`` as a key-padding mask, so it must not be converted into
    a broadcastable additive mask.

    Args:
        attention_mask: The mask supplied by the caller.
        input_shape: Unused; kept for signature compatibility.
        device: Unused; kept for signature compatibility.
        dtype: Unused; kept for signature compatibility.

    Returns:
        The very same ``attention_mask`` object.
    """
    return attention_mask
def forward_original(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Eager (matmul + softmax) ESM attention, used to undo the flash patch.

    Args:
        hidden_states: Input the query (and, for self-attention, the key and
            value) projections are applied to.
        attention_mask: Optional additive mask that is summed onto the raw
            attention scores before the softmax.
        head_mask: Optional multiplicative mask applied to the attention
            probabilities.
        encoder_hidden_states: When given, keys and values are projected from
            it (cross-attention) and ``encoder_attention_mask`` replaces
            ``attention_mask``.
        encoder_attention_mask: Mask used in the cross-attention case.
        past_key_value: Cached ``(key, value)`` pair. It is reused as-is for
            cross-attention and prepended along dim 2 for self-attention.
        output_attentions: Also return the attention probabilities.

    Returns:
        ``(context,)``, followed by the attention probabilities when
        ``output_attentions`` is truthy, followed by the ``(key, value)``
        pair when ``self.is_decoder`` is truthy.
    """
    projected_query = self.query(hidden_states)

    # With encoder states present this layer acts as cross-attention: keys
    # and values come from the encoder and the encoder's mask hides its
    # padding tokens.
    if encoder_hidden_states is not None:
        attention_mask = encoder_attention_mask
        if past_key_value is not None:
            # Cross-attention keys/values were cached by an earlier call.
            key_layer, value_layer = past_key_value[0], past_key_value[1]
        else:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        if past_key_value is not None:
            # Prepend the cached keys/values to the freshly projected ones.
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

    # ESM scales the query by 1/sqrt(head_dim) instead of scaling the logits.
    # The two are interchangeable only without rotary embeddings, so the
    # scaling is applied here, ahead of the rotation.
    query_layer = self.transpose_for_scores(projected_query) * self.attention_head_size**-0.5

    if self.is_decoder:
        # Expose the (pre-rotary) keys/values so later calls can reuse them:
        # directly for cross-attention, by concatenation for uni-directional
        # self-attention.
        past_key_value = (key_layer, value_layer)

    if self.position_embedding_type == "rotary":
        query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    # Raw attention scores: dot product of every query with every key.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type in ("relative_key", "relative_key_query"):
        seq_length = hidden_states.size()[1]
        positions = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        distance = positions.view(-1, 1) - positions.view(1, -1)
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        # Keep the embedding in the query's dtype (fp16 compatibility).
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

        query_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        if self.position_embedding_type == "relative_key":
            attention_scores = attention_scores + query_position_scores
        else:
            key_position_scores = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + query_position_scores + key_position_scores

    if attention_mask is not None:
        # The additive mask is added straight onto the scores.
        attention_scores = attention_scores + attention_mask

    # Turn scores into probabilities; dropout here removes whole attended-to
    # tokens, as in the original Transformer.
    attention_probs = self.dropout(nn.functional.softmax(attention_scores, dim=-1))

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # Weighted sum of values, then merge the head dimension back in.
    context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
    context_layer = context_layer.view(*context_layer.size()[:-2], self.all_head_size)

    outputs = (context_layer,)
    if output_attentions:
        outputs += (attention_probs,)
    if self.is_decoder:
        outputs += (past_key_value,)
    return outputs
def get_extended_attention_mask_original(
    self,
    attention_mask: torch.Tensor,
    input_shape: Tuple[int, ...],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

    Arguments:
        attention_mask (`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either a 2-D padding
            mask `[batch_size, seq_length]` or a 3-D mask `[batch_size, from_seq_length, to_seq_length]`.
        input_shape (`Tuple[int]`):
            The shape of the input to the model.
        device (`torch.device`, *optional*):
            Deprecated; only forwarded to the decoder mask helper.
        dtype (`torch.dtype`, *optional*):
            dtype of the returned mask. Defaults to `self.dtype`.

    Returns:
        `torch.Tensor`: The extended additive attention mask in `dtype`: 0.0 at positions to attend to and
        `torch.finfo(dtype).min` at masked positions.

    Raises:
        `ValueError`: If `attention_mask` is neither 2-D nor 3-D.
    """
    if dtype is None:
        dtype = self.dtype

    if not (attention_mask.dim() == 2 and self.config.is_decoder):
        # Show the warning only if it won't be shown in `create_extended_attention_mask_for_decoder`.
        if device is not None:
            logging.warning(
                "The `device` argument is deprecated and will be removed in v5 of Transformers."
            )

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves, in which case we just need to make it broadcastable to all heads.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # Provided a padding mask of dimensions [batch_size, seq_length]:
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to
        #   [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder:
            # Imported here because this module never brings ModuleUtilsMixin into scope at file level.
            from transformers.modeling_utils import ModuleUtilsMixin

            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                input_shape, attention_mask, device
            )
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and the dtype's smallest value for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
    return extended_attention_mask
def replace_esm_attn_with_flash_attn():
    """Monkey-patch Hugging Face ESM so self-attention runs through FlashAttention.

    Installs this module's ``forward`` as ``EsmSelfAttention.forward`` and
    its pass-through ``get_extended_attention_mask`` as
    ``EsmModel.get_extended_attention_mask``. The patch is applied to the
    classes, so it affects every ESM model in the process.

    A warning is logged when the current CUDA device reports a compute
    capability major version below 8. The capability query itself needs a
    usable CUDA device.
    """
    cuda_major, _ = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # Trailing space keeps the two adjacent literals from fusing into "backward.ref:".
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward
def replace_flash_attn_with_esm_attn():
    """Undo the FlashAttention monkey-patch on Hugging Face ESM.

    Re-installs this module's eager implementations: ``forward_original`` as
    ``EsmSelfAttention.forward`` and ``get_extended_attention_mask_original``
    as ``EsmModel.get_extended_attention_mask``.

    No GPU capability check is performed here: the eager path does not use
    FlashAttention, so restoring it must neither require a CUDA device nor
    log a FlashAttention hardware warning.
    """
    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask_original
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward_original
if __name__ == '__main__':
    # This module only provides patch functions; running it directly does nothing.
    pass