import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    # Same underlying autograd Fn as dropout_add_rms_norm below, called with
    # residual=None, bias=None, rowscale=None, layerscale=None, dropout_p=0.0,
    # residual_in_fp32=False, prenorm=False, and the RMSNorm flag set to True.
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )
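

# Reference semantics (an illustrative sketch, not the fused CUDA kernel):
# RMSNorm scales by the root-mean-square over the last dimension, with no mean
# subtraction and no bias. The helper name below is hypothetical and not part
# of the flash-attn API; it assumes the reduction is done in fp32 for stability.
def rms_norm_ref(x, weight, epsilon):
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # Apply the learned scale in fp32, then cast back to the input dtype.
    return ((x_f32 * inv_rms) * weight.float()).to(x.dtype)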


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual keeps its original dtype (residual.dtype).
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )
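

# Sketch of the intended semantics of the fused op above (an assumption based on
# the argument names and the usual prenorm convention, not a drop-in replacement
# for the kernel): dropout is applied to x0, the residual is added, and RMSNorm
# is applied to the sum; with prenorm=True the pre-normalization sum is returned
# as well. rowscale/layerscale and the dropout-mask return are omitted here.
# The helper name is hypothetical; rms_norm_ref is the reference defined above.
def dropout_add_rms_norm_ref(x0, residual, weight, dropout_p, epsilon, prenorm=False):
    dropped = torch.nn.functional.dropout(x0, p=dropout_p)
    pre = dropped + residual if residual is not None else dropped
    out = rms_norm_ref(pre, weight, epsilon)
    return (out, pre) if prenorm else out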


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual keeps its original dtype (residual.dtype).
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )
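

# Note on the subset variant above (an assumption from the argument names, not a
# spec): x0_subset and out_subset appear to select which rows of the flattened
# batch take part in the dropout-add and which rows of the normalized output are
# materialized (out_numrows of them), with rowscale_const applied as a constant
# row scaling. See DropoutAddLayerNormSubsetFn for the authoritative behavior.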


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual keeps its original dtype (residual.dtype).
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )
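

# Note on the parallel-residual variant above (an assumption from the argument
# names): a single sum dropout(x0) + dropout(x1) + residual is formed and then
# normalized twice, once with (weight0, bias0) and once with (weight1, bias1),
# as in GPT-J / GPT-NeoX style blocks where attention and MLP share a residual.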


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
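

# Example usage of RMSNorm (a sketch; the fused op requires a CUDA tensor and
# the compiled flash-attn extension):
#
#     norm = RMSNorm(1024, device="cuda", dtype=torch.float16)
#     y = norm(x)  # x: (batch, seqlen, 1024) half-precision CUDA tensor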


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        # Dropout is only active in training mode.
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
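

# Example usage of the fused dropout + add + RMSNorm module (a sketch; requires
# CUDA and the compiled extension). With prenorm=True the call returns both the
# normalized output and the updated residual stream:
#
#     mixer_norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1,
#                                    residual_in_fp32=True,
#                                    device="cuda", dtype=torch.float16)
#     hidden_states, residual = mixer_norm(hidden_states)            # first block
#     hidden_states, residual = mixer_norm(hidden_states, residual)  # later blocks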