# Copyright (c) 2024, Tri Dao.
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


class Block(nn.Module):
    """Transformer block made of a sequence mixer and an optional MLP.

    Each sub-layer is paired with a dropout, a stochastic-depth (drop-path)
    layer and a normalization layer. Depending on ``prenorm`` the block runs
    either the "Dropout -> Add -> Norm -> sub-layer" ordering (returning both
    the hidden states and the residual stream) or the regular post-norm
    ordering (returning only the hidden states).
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        Args:
            dim: value passed as the single positional argument to ``mixer_cls``,
                ``mlp_cls`` and ``norm_cls``.
            mixer_cls: factory for the mixer; defaults to ``MHA`` with ``dim // 64`` heads.
            mlp_cls: factory for the MLP; defaults to ``Mlp`` with ``4 * dim`` hidden
                features. If it produces an ``nn.Identity`` the second
                dropout / drop-path / norm are not created and the MLP stage is skipped.
            norm_cls: factory for the normalization layers.
            dropout_cls: factory for the residual dropout layers.
            prenorm: selects the pre-norm (True) or post-norm (False) ordering.
            resid_dropout1, resid_dropout2: arguments for the two dropout layers.
            drop_path1, drop_path2: drop probabilities of the two ``StochasticDepth``
                layers (created with ``mode="row"``).
            fused_dropout_add_ln: use ``layer_norm_fn`` to perform dropout, residual add
                and normalization in one call. Requires the Triton layer norm import
                to have succeeded.
            return_residual: whether each of the sub-layers (mixer and mlp) will return the
                residual. This is for performance reason: for post-norm architecture,
                returning the input allows us to fuse the backward of nn.Linear with the
                residual connection.
            residual_in_fp32: keep the residual stream in float32 (prenorm only).
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # norm2 only exists when the MLP is not an nn.Identity.
        norms = [self.norm1]
        if hasattr(self, "norm2"):
            norms.append(self.norm2)
        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for norm in norms:
                for p in norm.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for norm in norms:
                for p in norm.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the request to ``self.mixer.allocate_inference_cache``."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def _fused_dropout_add_norm(self, x, residual, norm, dropout, drop_path, prenorm):
        """Run dropout, drop-path, residual add and ``norm`` through ``layer_norm_fn``."""
        if drop_path.p == 0 or not self.training:
            rowscale = None
        else:
            # Passing a tensor of ones through the drop-path layer gives the scale
            # tensor handed to the fused call as ``rowscale``.
            rowscale = drop_path(torch.ones(x.shape[:-1], device=x.device, dtype=x.dtype))
        # residual_in_fp32 is only forwarded for the prenorm ordering.
        extra = {"residual_in_fp32": self.residual_in_fp32} if prenorm else {}
        return layer_norm_fn(
            x,
            norm.weight,
            norm.bias,
            residual=residual,
            eps=norm.eps,
            dropout_p=dropout.p if self.training else 0.0,
            rowscale=rowscale,
            prenorm=prenorm,
            is_rms_norm=isinstance(norm, RMSNorm),
            **extra,
        )

    def _prenorm_add_norm(self, x, residual, norm, dropout, drop_path):
        """Dropout -> Add -> Norm for the prenorm ordering; returns (normed, residual)."""
        if self.fused_dropout_add_ln:
            return self._fused_dropout_add_norm(
                x, residual, norm, dropout, drop_path, prenorm=True
            )
        dropped = drop_path(dropout(x))
        # Without an incoming residual the dropped input starts the residual stream.
        residual = dropped if residual is None else dropped + residual
        normed = norm(residual.to(dtype=norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        return normed, residual

    def _postnorm_add_norm(self, sublayer_out, skip, norm, dropout, drop_path):
        """Dropout -> Add -> Norm for the postnorm ordering; returns the normed tensor."""
        if self.fused_dropout_add_ln:
            return self._fused_dropout_add_norm(
                sublayer_out, skip, norm, dropout, drop_path, prenorm=False
            )
        summed = drop_path(dropout(sublayer_out)) + skip
        return norm(summed.to(dtype=norm.weight.dtype))

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer. Only used when prenorm=True; the
                postnorm branch does not read it.
            mixer_kwargs: optional keyword arguments forwarded to the mixer. The dict
                passed by the caller is never modified.

        Returns:
            ``(hidden_states, residual)`` when prenorm=True, otherwise ``hidden_states``.
        """
        # Work on a copy so that adding "mixer_subset" does not leak into the caller's dict.
        mixer_kwargs = dict(mixer_kwargs) if mixer_kwargs is not None else {}

        if self.prenorm:
            hidden_states, residual = self._prenorm_add_norm(
                hidden_states, residual, self.norm1, self.dropout1, self.drop_path1
            )
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Index the residual with the same subset that was given to the mixer.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                hidden_states, residual = self._prenorm_add_norm(
                    hidden_states, residual, self.norm2, self.dropout2, self.drop_path2
                )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual

        assert residual is None
        mixer_out = self.mixer(hidden_states, **mixer_kwargs)
        if self.return_residual:  # mixer out is actually a pair here
            mixer_out, hidden_states = mixer_out
        hidden_states = self._postnorm_add_norm(
            mixer_out, hidden_states, self.norm1, self.dropout1, self.drop_path1
        )
        if not isinstance(self.mlp, nn.Identity):
            mlp_out = self.mlp(hidden_states)
            if self.return_residual:  # mlp out is actually a pair here
                mlp_out, hidden_states = mlp_out
            hidden_states = self._postnorm_add_norm(
                mlp_out, hidden_states, self.norm2, self.dropout2, self.drop_path2
            )
        return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: value passed as the single positional argument to ``mixer_cls``,
                ``mlp_cls`` and ``norm_cls``.
            mixer_cls: factory for the mixer; defaults to ``MHA`` with ``dim // 64`` heads.
            mlp_cls: factory for the MLP; defaults to ``Mlp`` with ``4 * dim`` hidden features.
            norm_cls: factory for the normalization layers.
            dropout_cls: factory for the dropout layers.
            resid_dropout1, resid_dropout2: arguments for the two dropout layers.
            tied_norm: when True, no second norm is created and the MLP receives the
                same normalized tensor as the mixer.
            fused_dropout_add_ln: use ``layer_norm_fn`` to perform dropout, residual add
                and normalization in one call. Requires the Triton layer norm import
                to have succeeded.
            residual_in_fp32: keep the residual stream in float32.
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # norm2 only exists when the norms are not tied.
        norms = [self.norm1]
        if hasattr(self, "norm2"):
            norms.append(self.norm2)
        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for norm in norms:
                for p in norm.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for norm in norms:
                for p in norm.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the request to ``self.mixer.allocate_inference_cache``."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the residual stream from the previous block, or None.
            mixer_kwargs: optional keyword arguments forwarded to the mixer.

        Returns:
            ``(hidden_states1, hidden_states2, residual)``: the mixer output, the MLP
            output and the updated residual.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.tied_norm:
                # Tied norms: the MLP consumes the same normalized tensor as the mixer.
                hidden_states2 = hidden_states1
            else:
                hidden_states2 = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            if self.tied_norm:
                weight2, bias2 = None, None
            else:
                weight2, bias2 = self.norm2.weight, self.norm2.bias
            # NOTE(review): only self.dropout1.p is handed to the fused call; self.dropout2
            # is not consulted on this path -- confirm that is intended.
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                # Exactly one extra output is expected between the first output and the residual.
                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual