Spaces:
Sleeping
Sleeping
| # Copyright (c) 2022, Tri Dao. | |
| # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py | |
| import dropout_layer_norm | |
| import torch | |
| from torch.nn import init | |
def maybe_align(x, alignment_in_bytes=16):
    """Return ``x`` unchanged if its storage pointer is aligned, else a clone.

    Assumes that ``x`` already has its last dim divisible by
    ``alignment_in_bytes``.
    """
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    if x.data_ptr() % alignment_in_bytes == 0:
        return x
    return x.clone()
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Flatten the inputs to 2D and call the fused forward CUDA kernel.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    Returns ``(zmat, xmat, dmask, mu, rsigma)``.
    """
    cols = gamma.numel()
    x0_2d = x0.view((-1, cols))
    res_2d = None if residual is None else residual.view((-1, cols))
    rowscale_flat = None if rowscale is None else rowscale.view(-1)
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0_2d,
        res_2d,
        gamma,
        beta,
        rowscale_flat,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    if xmat is None:
        xmat = x0_2d
    return zmat, xmat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Flatten the inputs to 2D and call the fused backward CUDA kernel.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    cols = gamma.numel()
    x_2d = x.view((-1, cols))
    dz_2d = dz.view(x_2d.shape)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    x0_2d = None if x0 is None else x0.view((-1, cols))
    rowscale_flat = None if rowscale is None else rowscale.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    results = dropout_layer_norm.dropout_add_ln_bwd(
        dz_2d,
        dx_2d,
        x_2d,
        x0_2d,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale_flat,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    dx0mat, dresidualmat, dgamma, dbeta = results[:4]
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # Position 6 of the kernel output holds dcolscale when colscale was given.
    return dx0mat, dresidualmat, dgamma, dbeta, results[6]
def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Row-subset variant of the fused forward kernel call.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    Returns ``(zmat, xmat, dmask, mu, rsigma)``.
    """
    cols = gamma.numel()
    x0_2d = x0.view((-1, cols))
    res_2d = None if residual is None else residual.view((-1, cols))
    x0_rows = None if x0_subset is None else x0_subset.view(-1)
    out_rows = None if out_subset is None else out_subset.view(-1)
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0_2d,
        res_2d,
        gamma,
        beta,
        None,
        colscale,
        x0_rows,
        out_rows,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    if xmat is None:
        xmat = x0_2d
    return zmat, xmat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Row-subset variant of the fused backward kernel call.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    cols = gamma.numel()
    x_2d = x.view((-1, cols))
    # NOTE: deliberately not x_2d.shape — dz's row count may differ from x's
    # when only a subset of output rows was produced (out_subset).
    dz_2d = dz.view(-1, cols)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    x0_2d = None if x0 is None else x0.view((-1, cols))
    x0_rows = None if x0_subset is None else x0_subset.view(-1)
    out_rows = None if out_subset is None else out_subset.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    results = dropout_layer_norm.dropout_add_ln_bwd(
        dz_2d,
        dx_2d,
        x_2d,
        x0_2d,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_rows,
        out_rows,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    dx0mat, dresidualmat, dgamma, dbeta = results[:4]
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # Position 6 of the kernel output holds dcolscale when colscale was given.
    return dx0mat, dresidualmat, dgamma, dbeta, results[6]
def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Parallel-residual (two-stream) variant of the fused forward kernel call.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    Returns ``(z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma)``.
    """
    cols = gamma0.numel()
    x0_2d = x0.view((-1, cols))
    x1_2d = None if x1 is None else x1.view((-1, cols))
    res_2d = None if residual is None else residual.view((-1, cols))
    out = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0_2d,
        x1_2d,
        res_2d,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = out
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    if xmat is None:
        xmat = x0_2d
    return z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma
def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Parallel-residual (two-stream) variant of the fused backward kernel call.

    Assumes that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    cols = gamma0.numel()
    x_2d = x.view((-1, cols))
    dz0_2d = dz0.view(x_2d.shape)
    dz1_2d = None if dz1 is None else dz1.view(x_2d.shape)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    out = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0_2d,
        dz1_2d,
        dx_2d,
        x_2d,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = out[:7]
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd function for fused dropout + residual-add + LayerNorm/RMSNorm.

    Fix: ``forward``/``backward`` were plain methods; ``torch.autograd.Function``
    requires them to be static methods, so the ``@staticmethod`` decorators are
    added (behavior otherwise unchanged).
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the fused forward kernel; saves what backward needs on ``ctx``.

        Returns z (and the pre-norm x when ``prenorm``), plus the dropout mask
        when ``return_dmask``.
        """
        # The kernel assumes 16-byte-aligned, contiguous inputs.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            # dmask is None when dropout_p == 0.0; materialize an all-ones mask then.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        """Run the fused backward kernel; returns one grad per forward input."""
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,  # rowscale
            dcolscale,
            None,  # dropout_p
            None,  # epsilon
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    """Row-subset autograd function for fused dropout + add + LayerNorm/RMSNorm.

    Fix: ``forward``/``backward`` were plain methods; ``torch.autograd.Function``
    requires them to be static methods, so the ``@staticmethod`` decorators are
    added (behavior otherwise unchanged).
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the subset forward kernel; saves what backward needs on ``ctx``."""
        # The kernel assumes 16-byte-aligned, contiguous inputs.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            # dmask is None when dropout_p == 0.0; materialize an all-ones mask then.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        """Run the subset backward kernel; returns one grad per forward input."""
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,  # x0_subset
            None,  # out_subset
            None,  # dropout_p
            None,  # epsilon
            None,  # rowscale_const
            None,  # out_numrows
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    """Two-stream (parallel residual) autograd function for the fused kernel.

    Fix: ``forward``/``backward`` were plain methods; ``torch.autograd.Function``
    requires them to be static methods, so the ``@staticmethod`` decorators are
    added (behavior otherwise unchanged).
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the parallel-residual forward kernel; saves state on ``ctx``."""
        # The kernel assumes 16-byte-aligned, contiguous inputs.
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            # Masks are None when dropout_p == 0.0 (or x1 is absent for dmask1);
            # materialize all-ones masks in those cases.
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        """Run the parallel-residual backward kernel; one grad per forward input."""
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,  # dropout_p
            None,  # epsilon
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
def layer_norm(x, weight, bias, epsilon):
    """Plain layer norm via the fused path: no residual, no dropout, no scales."""
    return DropoutAddLayerNormFn.apply(
        x, None, weight, bias, None, None, 0.0, epsilon, False
    )
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Apply the fused dropout + residual-add + LayerNorm autograd function.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # Function.apply takes positional arguments only; keep the order in sync
    # with DropoutAddLayerNormFn.forward.
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormFn.apply(*fn_args)
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Apply the row-subset fused dropout + add + LayerNorm autograd function.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # Function.apply takes positional arguments only; keep the order in sync
    # with DropoutAddLayerNormSubsetFn.forward.
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormSubsetFn.apply(*fn_args)
def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Apply the two-stream (parallel residual) fused autograd function.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # Function.apply takes positional arguments only; keep the order in sync
    # with DropoutAddLayerNormParallelResidualFn.forward.
    fn_args = (
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormParallelResidualFn.apply(*fn_args)
class DropoutAddLayerNorm(torch.nn.Module):
    """Module wrapper: LayerNorm(dropout(x0) + residual) via the fused kernel."""

    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        # Identity-style init: unit scale, zero shift.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        # Dropout is disabled in eval mode.
        effective_p = self.p if self.training else 0.0
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            effective_p,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )