import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

import transformers
from einops import rearrange
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from transformers.modeling_utils import ModuleUtilsMixin
| |
|
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """FlashAttention replacement for ``EsmSelfAttention.forward``.

    Supports only the encoder-style, rotary-embedding configuration of ESM:
    cross-attention, KV caching, head masking and returning attention
    probabilities are rejected up front.

    Args:
        hidden_states: input of shape ``(batch, seq_len, hidden_size)``.
        attention_mask: required 0/1 padding mask of shape
            ``(batch, seq_len)`` — this patch pairs with an overridden
            ``get_extended_attention_mask`` that passes the raw mask through.
        head_mask / encoder_* / past_key_value / output_attentions: must be
            ``None`` / ``False``; asserted below.

    Returns:
        A 1-tuple containing the attention output of shape
        ``(batch, seq_len, hidden_size)``.
    """
    assert encoder_hidden_states is None, "Cross-attention is not supported for ESM"
    assert past_key_value is None, "Past key value is not supported for ESM"
    assert self.is_decoder is False, "Decoder is not supported for ESM"
    assert self.position_embedding_type == "rotary", "Rotary embeddings are required for ESM"
    assert head_mask is None, "Head mask is not supported for ESM"
    assert output_attentions is False, "Output attentions is not supported for ESM"
    assert attention_mask is not None

    # Project and reshape to (batch, heads, seq, head_dim).
    query_layer = self.transpose_for_scores(self.query(hidden_states))
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))

    # ESM scales the query itself (1/sqrt(head_dim)); flash_attn is called
    # with softmax_scale=1 below so the scaling is not applied twice.
    query_layer = query_layer * self.attention_head_size**-0.5

    # position_embedding_type == "rotary" is guaranteed by the assert above.
    query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    # Pack Q/K/V into (batch, seq, 3, heads, head_dim) for the qkvpacked kernel.
    qkv = torch.stack([query_layer, key_layer, value_layer], dim=2).transpose(1, 3)
    bsz, q_len, _ = hidden_states.size()
    nheads = qkv.shape[-2]

    # Drop padded positions so the varlen kernel only touches real tokens.
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, attention_mask)
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    # flash_attn requires fp16/bf16 inputs.
    # NOTE(review): the output consequently comes back in bfloat16 even when
    # the model runs in fp32 — same as the original patch; confirm downstream
    # layers tolerate the dtype.
    x_unpad = x_unpad.to(torch.bfloat16)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_q_lens,
        max_s,
        self.dropout.p if self.training else 0.0,
        softmax_scale=1,  # query was pre-scaled above
        causal=False,
    )
    # Scatter back to the padded (batch, seq, hidden) layout.
    # (A dead `if False:` branch that re-rearranged the output was removed.)
    outputs = pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len)
    return (outputs,)
| |
|
| |
|
def get_extended_attention_mask(
    self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
) -> torch.Tensor:
    """Pass-through override used by the FlashAttention patch.

    The varlen flash-attention path unpads sequences with the raw 0/1
    padding mask itself, so the usual broadcastable additive (-inf) mask is
    deliberately not built; the mask is returned untouched.
    """
    # `device`/`dtype` are accepted only for signature compatibility.
    raw_mask = attention_mask
    return raw_mask
| |
|
| |
|
def forward_original(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Stock ESM self-attention forward pass, kept so the flash patch can be
    reverted.

    Computes scaled dot-product attention over ``hidden_states``, supporting
    cross-attention (via ``encoder_hidden_states``), decoder KV caching (via
    ``past_key_value``), rotary / relative position embeddings, additive
    attention masks, and optional head masking.

    Returns a tuple of ``(context_layer,)``, with ``attention_probs``
    appended when ``output_attentions`` is set, and ``past_key_value``
    appended when ``self.is_decoder`` is set.
    """
    mixed_query_layer = self.query(hidden_states)

    # Cross-attention is detected by the presence of encoder states.
    is_cross_attention = encoder_hidden_states is not None

    if is_cross_attention and past_key_value is not None:
        # Reuse cached cross-attention keys/values; the encoder mask applies.
        key_layer = past_key_value[0]
        value_layer = past_key_value[1]
        attention_mask = encoder_attention_mask
    elif is_cross_attention:
        # First cross-attention step: project the encoder states.
        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        attention_mask = encoder_attention_mask
    elif past_key_value is not None:
        # Cached self-attention: append new K/V along the sequence axis (dim=2).
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
        value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
    else:
        # Plain self-attention without cache.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

    query_layer = self.transpose_for_scores(mixed_query_layer)

    # ESM scales the query (rather than the QK product) by 1/sqrt(head_dim).
    query_layer = query_layer * self.attention_head_size**-0.5

    if self.is_decoder:
        # Cache the (possibly concatenated) K/V for the next decoding step.
        past_key_value = (key_layer, value_layer)

    if self.position_embedding_type == "rotary":
        # Rotary embeddings are applied after scaling, to Q and K only.
        query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    # Raw attention scores: (batch, heads, q_len, k_len).
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        # Pairwise token distances, shifted into the embedding index range.
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            # Both query- and key-conditioned relative scores are added.
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    if attention_mask is not None:
        # Additive mask (0 for keep, large negative for masked positions).
        attention_scores = attention_scores + attention_mask

    # Normalize scores to probabilities over the key axis.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)

    # Dropout on attention probabilities (drops whole tokens to attend to).
    attention_probs = self.dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    # Back to (batch, seq, heads, head_dim), then merge heads.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs
| |
|
| |
|
def get_extended_attention_mask_original(
    self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
) -> torch.Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

    Arguments:
        attention_mask (`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
        input_shape (`Tuple[int]`):
            The shape of the input to the model.
        device (`torch.device`, deprecated):
            Ignored except for a deprecation warning.
        dtype (`torch.dtype`, optional):
            Dtype of the returned mask; defaults to `self.dtype`.

    Returns:
        `torch.Tensor` The extended additive attention mask, cast to `dtype`.
    """
    if dtype is None:
        dtype = self.dtype

    if not (attention_mask.dim() == 2 and self.config.is_decoder):
        # show warning only if it won't be shown in create_extended_attention_mask_for_decoder
        if device is not None:
            # BUGFIX: was a bare print(); route the deprecation notice
            # through logging like the rest of this module.
            logging.warning(
                "The `device` argument is deprecated and will be removed in v5 of Transformers."
            )

    # A provided 3-D mask [batch, from_seq, to_seq] is broadcast over heads;
    # a 2-D padding mask [batch, seq] is broadcast over heads and query positions.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        if self.config.is_decoder:
            # BUGFIX: ModuleUtilsMixin was referenced here without being
            # imported (NameError on this path); it is now imported at
            # module level from transformers.modeling_utils.
            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                input_shape, attention_mask, device
            )
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    # Convert the 0/1 mask to additive form: positions to attend become 0.0,
    # masked positions become the dtype's minimum, so adding the mask to raw
    # scores effectively removes the masked positions before softmax.
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
    return extended_attention_mask
| |
|
| |
|
def replace_esm_attn_with_flash_attn():
    """Monkey-patch ESM so self-attention runs through FlashAttention.

    Overrides ``EsmModel.get_extended_attention_mask`` (pass-through mask)
    and ``EsmSelfAttention.forward`` (varlen flash kernel) in place.
    """
    # BUGFIX: torch.cuda.get_device_capability() raises on CPU-only
    # machines; check availability before querying the capability.
    if not torch.cuda.is_available():
        logging.warning("CUDA is not available; flash attention requires a GPU.")
    else:
        cuda_major, _cuda_minor = torch.cuda.get_device_capability()
        if cuda_major < 8:
            logging.warning(
                "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
                "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
            )
    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward
| |
|
| |
|
def replace_flash_attn_with_esm_attn():
    """Revert the FlashAttention monkey-patch, restoring stock ESM attention.

    Restores the original ``EsmModel.get_extended_attention_mask`` and
    ``EsmSelfAttention.forward`` implementations kept in this module.
    """
    # BUGFIX: the CUDA-capability check/warning here was copy-pasted from
    # replace_esm_attn_with_flash_attn(); it crashed on CPU-only machines
    # and is irrelevant when switching BACK to standard attention, which
    # needs no GPU. It has been removed.
    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask_original
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward_original
| |
|
if __name__ == '__main__':
    # This module is import-only (it exposes monkey-patch helpers);
    # nothing runs when executed directly.
    pass