# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    # Reuse the fused dropout-add-layernorm kernel with no residual, no bias,
    # no dropout, and is_rms_norm=True (the final positional argument).
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )
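

# Usage sketch (not part of the original file): the fused kernel requires a
# CUDA device and a float16/bfloat16/float32 input; the sizes below are
# illustrative assumptions. Wrapped in a function so importing the module
# stays side-effect-free.
def _example_rms_norm():
    hidden = 512
    x = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    out = rms_norm(x, weight, epsilon=1e-5)
    assert out.shape == x.shape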


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual dtype is taken from residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
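

# Usage sketch (not part of the original file): with prenorm=True the call
# returns both the normalized output and the updated residual stream
# (dropout(x0) added to the incoming residual), the usual pre-norm pattern.
def _example_dropout_add_rms_norm():
    hidden = 512
    x0 = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    out, new_residual = dropout_add_rms_norm(
        x0, residual, weight, None, 0.1, 1e-5, prenorm=True
    )
    assert out.shape == x0.shape and new_residual.shape == x0.shape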


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual dtype is taken from residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
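

# Usage sketch (not part of the original file): with x0_subset/out_subset left
# at their None defaults this reduces to the plain fused path; the subset
# arguments exist for normalizing only selected rows (e.g., token dropping).
# Treating None as "all rows" is an assumption based on the defaults above.
def _example_dropout_add_rms_norm_subset():
    hidden = 512
    x0 = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    out = dropout_add_rms_norm_subset(x0, None, weight, None, 0.0, 1e-5)
    assert out.shape == x0.shape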


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual dtype is taken from residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
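

# Usage sketch (not part of the original file): the parallel-residual variant
# (GPT-J / GPT-NeoX-style blocks) adds two branches x0 and x1 into a single
# residual stream and normalizes it with a separate weight per output. The
# (out0, out1, residual) return under prenorm=True is an assumption mirroring
# the non-parallel prenorm convention.
def _example_parallel_residual():
    hidden = 512
    x0 = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    x1 = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    w0 = torch.ones(hidden, device="cuda", dtype=torch.float16)
    w1 = torch.ones(hidden, device="cuda", dtype=torch.float16)
    out0, out1, residual = dropout_add_rms_norm_parallel_residual(
        x0, x1, None, w0, None, w1, None, 0.0, 1e-5, prenorm=True
    )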


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # RMSNorm has no bias; register it as None so the module API matches LayerNorm.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
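

# Usage sketch (not part of the original file): RMSNorm is a drop-in module
# with a ones-initialized weight and no bias; inputs must live on a CUDA device.
def _example_rmsnorm_module():
    norm = RMSNorm(512, device="cuda", dtype=torch.float16)
    x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
    y = norm(x)
    assert y.shape == x.shape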


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        # Dropout is active only in training mode; at eval time p is forced to 0.
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
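

# Usage sketch (not part of the original file): with prenorm=True the module
# returns (normed, residual); passing residual=None on the first call starts
# the residual stream, and dropout is applied only in training mode.
def _example_dropout_add_rmsnorm_module():
    norm = DropoutAddRMSNorm(512, prenorm=True, p=0.1, device="cuda", dtype=torch.float16)
    x0 = torch.randn(4, 512, device="cuda", dtype=torch.float16)
    normed, residual = norm(x0)                 # first layer: residual starts here
    normed2, residual = norm(normed, residual)  # later layers reuse the stream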