# examples/modular-transformers/modeling_from_uppercase_model.py
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_from_uppercase_model.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional

import torch
from torch import nn

from ...activations import ACT2FN
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_from_uppercase_model import FromUppercaseModelConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)


class FromUppercaseModelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
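
    # Helper used below to reshape projections from (bsz, seq_len, embed_dim)
    # to (bsz, num_heads, seq_len, head_dim) so attention is computed per head.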
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
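
        # Collapse the batch and head dimensions so each head's attention can be
        # computed with a single batched matrix multiply (torch.bmm) below.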
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
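        # Raw attention scores of shape (bsz * num_heads, tgt_len, src_len); the queries
        # were already multiplied by `self.scale`, so this is the scaled dot product.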
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
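
        # Normalize the (masked) scores into attention probabilities over the source positions.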
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention):
    """
    FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        output_attentions = False

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we only need to split the hidden dimension into heads
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
        # in fp32.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be because"
                f" you have upcast embedding or layer norm layers in float32. We will cast back the input to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
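
        # Note: in this path `attention_mask` is expected to be the 2-D padding mask;
        # `_flash_attention_forward` uses it to unpad the batch, while causal masking is
        # requested via the `is_causal` flag rather than via an explicit additive mask.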
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=causal_attention_mask is not None,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention):
    """
    SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `FromUppercaseModelAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from FromUppercaseModelAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
                'be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )

        # FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask

        bsz, tgt_len, embed_dim = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
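
        # After the transpose, q/k/v have shape (bsz, num_heads, seq_len, head_dim), which is
        # the layout torch.nn.functional.scaled_dot_product_attention expects.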
        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` sequentially.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            scale=self.scale,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None


class FromUppercaseModelMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
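

# Maps `config._attn_implementation` to the attention class instantiated by the encoder layer below.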
FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = {
    "eager": FromUppercaseModelAttention,
    "sdpa": FromUppercaseModelSdpaAttention,
    "flash_attention_2": FromUppercaseModelFlashAttention2,
}


class FromUppercaseModelEncoderLayer(nn.Module):
    def __init__(self, config: FromUppercaseModelConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = FromUppercaseModelMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where masked positions are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states
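
        # Second residual block: normalize, apply the feed-forward MLP, then add the residual.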
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
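

# Illustrative usage sketch (assumes `FromUppercaseModelConfig` exposes `hidden_size`,
# `num_attention_heads`, `intermediate_size`, `hidden_act`, `layer_norm_eps`,
# `attention_dropout`, and an `_attn_implementation` that defaults to "eager"):
#
#     config = FromUppercaseModelConfig()
#     layer = FromUppercaseModelEncoderLayer(config)
#     hidden_states = torch.randn(1, 4, config.hidden_size)
#     (hidden_states,) = layer(hidden_states, attention_mask=None, causal_attention_mask=None)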