import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None when dropout_p == 0.0; xmat is only returned when the kernel had to
    # materialize x = dropout(x0) + residual separately from x0, so fall back to x0mat otherwise.
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 when the gradient wrt colscale has to be computed.
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None when colscale is None (see forward).
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 when the gradient wrt colscale has to be computed.
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x_shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
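
# `layer_norm` runs the fused kernel as a plain LayerNorm: no dropout, no residual, no
# row/col scaling. A minimal usage sketch (tensor names are illustrative, assuming CUDA
# inputs whose last dimension matches `weight`):
#
#     out = layer_norm(x, weight, bias, 1e-5)
#
# which should agree with torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, 1e-5)
# up to the numerical tolerance of the fused kernel.
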
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
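
# A minimal usage sketch for dropout_add_layer_norm (shapes and names below are illustrative,
# not part of this module): it fuses out = LayerNorm(dropout(x0) + residual); with
# prenorm=True it additionally returns the pre-normalization sum x = dropout(x0) + residual,
# which the caller can carry forward as the next block's residual stream.
#
#     weight = torch.ones(1024, device="cuda", dtype=torch.float16)
#     bias = torch.zeros(1024, device="cuda", dtype=torch.float16)
#     x0 = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
#     residual = torch.randn_like(x0)
#     out, x = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5, prenorm=True)
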
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
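
# Subset variant: x0 and/or the normalized output may cover only a subset of the rows of the
# residual stream (e.g. only the tokens that contribute to the loss). x0_subset and out_subset
# are the row-index tensors describing that mapping, rowscale_const scales x0 before the add,
# and out_numrows is the number of output rows to produce; the exact index convention is
# defined by the compiled dropout_layer_norm kernel.
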
def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
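
# Parallel-residual variant (GPT-J / GPT-NeoX style blocks): a single shared
# x = dropout(x0) + dropout(x1) + residual is formed, then normalized twice with
# (weight0, bias0) and (weight1, bias1) to give z0 and z1. A sketch with illustrative names:
#
#     z0, z1, x = dropout_add_layer_norm_parallel_residual(
#         attn_out, mlp_out, residual, w0, b0, w1, b1, 0.1, 1e-5, prenorm=True
#     )
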
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
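
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the library: it assumes a CUDA device
    # and the compiled dropout_layer_norm extension are available; shapes/dtypes illustrative.
    torch.manual_seed(0)
    hidden_size = 1024
    x0 = torch.randn(8, 512, hidden_size, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x0)
    module = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=0.1, device="cuda", dtype=torch.float16
    )
    # With prenorm=True the module returns both the normalized output and the
    # pre-normalization residual stream x = dropout(x0) + residual.
    out, x = module(x0, residual)
    print(out.shape, x.shape, out.dtype)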