diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..0cd58331b2a989b68be4ec5676383437fca8687b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..64510a07ab488778f7867d2fd0e8e47830a8fcf3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/build/temp*
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5914c9964973d4ccb0ebcf3785ee2e52825caf2b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,20 @@
+This CUDA extension implements fused dropout + residual + LayerNorm, building on
+Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
+Major changes:
+- Add dropout and residual.
+- Make it work for both pre-norm and post-norm architectures.
+- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
+- Implement RMSNorm as an option.
+- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
+
+If you want to use it for dimensions larger than 8k, please file an issue.
+
+This extension has only been tested on A100s.
+
+```sh
+cd csrc/layer_norm && pip install .
+```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
diff --git a/api.py b/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b6cd798fd02844ef9cd3897f8ab95e490e638bf
--- /dev/null
+++ b/api.py
@@ -0,0 +1,800 @@
+# Copyright (c) 2022, Tri Dao.
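+#
+# Minimal usage sketch (assumes the extension was built per the README, i.e.
+# `cd csrc/layer_norm && pip install .`, so that `dropout_layer_norm` imports):
+#
+#     out = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5,
+#                                  prenorm=False, residual_in_fp32=True)
+#
+# x0 and residual have shape (..., hidden_size); weight and bias have shape
+# (hidden_size,). With prenorm=True, the call also returns x = dropout(x0) + residual
+# for use as the next block's residual stream.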
+# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + +import dropout_layer_norm +import torch +from torch.nn import init + + +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes""" + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + +def _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + rowscale, + colscale, + None, + None, + dropout_p, + epsilon, + 1.0, + 0, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + None, + None, + dropout_p, + 1.0, + 0, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma0.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( + x0mat, + x1mat, + residualmat, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask0 and dmask1 are None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma + + +def _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). 
+ """ + hidden_size = gamma0.numel() + xmat = x.view((-1, hidden_size)) + dz0mat = dz0.view(xmat.shape) + dz1mat = dz1.view(xmat.shape) if dz1 is not None else None + dxmat = dx.view(xmat.shape) if dx is not None else None + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + *rest, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( + dz0mat, + dz1mat, + dxmat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 + + +class DropoutAddLayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward( + xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + if not return_dmask: + return ( + zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) + ) + else: + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return ( + (zmat.view(x0.shape), dmask) + if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) + ) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + None, + dcolscale, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward( + xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) + else: + z = zmat.view(z_shape) + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + ctx.rowscale_const, + ctx.x0_numrows, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + dcolscale, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_x1 = x1 is not None + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta0 is not None + z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) + if not return_dmask: + return z if not prenorm else (*z, xmat.view(x0.shape)) + else: + dmask0 = ( + dmask0.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + dmask1 = ( + dmask1.view(x0.shape) + if dropout_p > 0.0 and x1 is not None + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask0) + ctx.mark_non_differentiable(dmask1) + return ( + (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) + ) + + @staticmethod + def backward(ctx, dz0, dz1, *args): + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_x1 = ctx.has_x1 + has_residual = ctx.has_residual + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + ) = _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + return ( + dx0, + dx1, + dresidual, + dgamma0, + dbeta0 if ctx.has_beta else None, + dgamma1, + dbeta1 if ctx.has_beta else None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm(x, weight, bias, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) + + +def dropout_add_layer_norm( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + rowscale=None, + layerscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, + residual, + weight, + bias, + rowscale, + layerscale, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_subset( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + layerscale=None, + x0_subset=None, + out_subset=None, + rowscale_const=1.0, + out_numrows=0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, + residual, + weight, + bias, + layerscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_parallel_residual( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + prenorm=False, + p=0.0, + eps=1e-5, + residual_in_fp32=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, residual=None): + return dropout_add_layer_norm( + x0, + residual, + self.weight, + self.bias, + self.p if self.training else 0.0, + self.eps, + prenorm=self.prenorm, + residual_in_fp32=self.residual_in_fp32, + ) diff --git a/build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so b/build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so new file mode 100755 index 0000000000000000000000000000000000000000..c4b2ea8084f001a9fa23cfa4c56e9838fe1926d3 --- /dev/null +++ b/build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aba674c175147bfdff6acb354745749070519df35f66522d71b2743aedc3b5a9 +size 26705096 diff --git a/ln.h b/ln.h new file mode 100644 index 0000000000000000000000000000000000000000..9830c092d0aca9f3466154a18d1d3c32d651716e --- /dev/null +++ b/ln.h @@ -0,0 +1,281 @@ +#pragma once + +#include +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LaunchParams{ + + size_t elts_per_thread; + size_t workspace_bytes; + size_t barrier_size; + + cudaDeviceProp * props; + + cudaStream_t stream; + + Params params; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ParamsBase { + ParamsBase() + : ctas_per_col(0) + , rows(0) + , cols(0) + , x(nullptr) + , mu(nullptr) + , rs(nullptr) + , gamma(nullptr) + , gamma1(nullptr) + , rowscale(nullptr) + , colscale(nullptr) + , dropout_keep_p(1.f) + , dropout_scale(1.f) + , is_rms_norm(false) + , workspace(nullptr) + , barrier(nullptr) + { + } + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void *x0; + void *x1; + void *residual; + void *x; + void *dmask; + void *dmask1; + void *mu; + void *rs; + void *gamma; + void *gamma1; + void *rowscale; + void *colscale; + void *x0_subset; + void *z_subset; + + float inverse_cols; + + float dropout_keep_p; + float dropout_scale; + float rowscale_const; + + bool is_rms_norm; + + // Multi-CTA workspace in gmem. + void *workspace; + + // Multi-CTA sync barriers in gmem. 
+    int *barrier;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct FwdParams : public ParamsBase {
+    FwdParams()
+        : ParamsBase()
+        , z(nullptr)
+        , z1(nullptr)
+        , beta(nullptr)
+        , beta1(nullptr)
+        , epsilon(0.f)
+    {
+    }
+
+    // Output of LN FWD.
+    void *z;
+    void *z1;
+    void *beta;
+    void *beta1;
+    float epsilon;
+
+    // Random state.
+    at::PhiloxCudaState philox_args;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct BwdParams : public ParamsBase {
+    BwdParams()
+        : ParamsBase()
+        , dz(nullptr)
+        , dz1(nullptr)
+        , dx(nullptr)
+        , dbeta_part(nullptr)
+        , dgamma_part(nullptr)
+        , dbeta1_part(nullptr)
+        , dgamma1_part(nullptr)
+        , dcolscale_part(nullptr)
+        , dx0(nullptr)
+        , dx1(nullptr)
+        , dresidual(nullptr)
+        , dbeta(nullptr)
+        , dgamma(nullptr)
+        , dbeta1(nullptr)
+        , dgamma1(nullptr)
+        , dcolscale(nullptr)
+    {
+    }
+
+    // Input: gradient wrt. LN FWD output.
+    void *dz;
+    void *dz1;
+    // Input: gradient wrt residual.
+    void *dx;
+
+    // Workspace for Wgrad pre-reduction.
+    void *dbeta_part;
+    void *dgamma_part;
+    void *dbeta1_part;
+    void *dgamma1_part;
+    void *dcolscale_part;
+
+    // Output: Dgrad.
+    void *dx0;
+    void *dx1;
+    void *dresidual;
+    // Output: Wgrad.
+    void *dbeta;
+    void *dgamma;
+    void *dbeta1;
+    void *dgamma1;
+    void *dcolscale;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
+using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
+using FunctionKey = uint64_t;
+using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
+using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
+
+extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
+extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+using fp32 = float;
+using fp16 = half;
+using bf16 = nv_bfloat16;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct TypeId{};
+
+template<>
+struct TypeId<fp16>{
+    constexpr static uint32_t Value = 0;
+};
+
+template<>
+struct TypeId<bf16>{
+    constexpr static uint32_t Value = 1;
+};
+
+template<>
+struct TypeId<fp32>{
+    constexpr static uint32_t Value = 2;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int S>
+struct Type2Key{
+    constexpr static uint32_t Value = TypeId<T>::Value << S;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct WeightType2Key : public Type2Key<T, 0>{};
+
+template<typename T>
+struct InputType2Key : public Type2Key<T, 2>{};
+
+template<typename T>
+struct ResidualType2Key : public Type2Key<T, 4>{};
+
+template<typename T>
+struct OutputType2Key : public Type2Key<T, 6>{};
+
+template<typename T>
+struct ComputeType2Key : public Type2Key<T, 8>{};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C>
+struct Types2Key{
+    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
+    constexpr static inline uint64_t get(const uint64_t hidden_size){
+        constexpr uint64_t type_key = Value;
+        return (type_key << 32) | hidden_size;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct FwdRegistrar{
+    FwdRegistrar(FwdFunction f){
+        uint64_t key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
+        FWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct BwdRegistrar{
+    BwdRegistrar(BwdFunction f){
+        uint64_t key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
+        BWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct FwdParallelRegistrar{
+    FwdParallelRegistrar(FwdFunction f){
+        uint64_t key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
+        PARALLEL_FWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct BwdParallelRegistrar{
+    BwdParallelRegistrar(BwdFunction f){
+        uint64_t key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
+        PARALLEL_BWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace layer_norm
diff --git a/ln_api.cpp b/ln_api.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3981bbad58e56023c33ff66b89c130f4d1636a36
--- /dev/null
+++ b/ln_api.cpp
@@ -0,0 +1,850 @@
+#include <torch/extension.h>
+#include "ATen/cuda/CUDAContext.h"
+#include <vector>
+
+#include "ln.h"
+
+/*
+
+Supported Type combinations:
+
+input  residual  compute  weights  output
+============================================
+fp32   fp32      fp32     fp32     fp32
+fp16   fp32      fp32     fp32     fp16
+fp16   fp16      fp32     fp32     fp16
+bf16   fp32      fp32     fp32     bf16
+bf16   bf16      fp32     fp32     bf16
+fp16   fp16      fp32     fp16     fp16
+bf16   bf16      fp32     bf16     bf16
+
+Remarks:
+Output type = Input type
+Compute always in FP32
+
+*/
+
+namespace layer_norm {
+
+// Create registries and provide runtime versions of config hash functions.
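+// The four registries below are populated at static-initialization time by the
+// per-hidden-size translation units (e.g. via REGISTER_BWD_LAUNCHER in ln_bwd_1024.cu),
+// so linking a .cu file is what makes its hidden size dispatchable at runtime.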
+ +FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; +BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +uint32_t get_type_id(torch::Dtype dtype){ + if( dtype == torch::kFloat16 ) { + return TypeId::Value; + } else if( dtype == torch::kBFloat16 ) { + return TypeId::Value; + } else if( dtype == torch::kFloat32 ) { + return TypeId::Value; + } else { + TORCH_CHECK(false, "Type not supported: ", dtype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { + using namespace layer_norm; + uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; +} + +} // namespace layer_norm + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::FWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::BWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
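+// Illustrative worked example of the key packing above (not in the original source):
+// for wtype=fp32, itype=fp16, rtype=fp16, otype=fp16, ctype=fp32, get_type_id returns
+// 2, 0, 0, 0, 2, so type_key = 2 | (0 << 2) | (0 << 4) | (0 << 6) | (2 << 8) = 0x202
+// and, with hidden_size = 1024, the launcher key is (0x202ULL << 32) | 1024.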
+std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size + c10::optional &residual_, // Residual: BxSxhidden_size + const at::Tensor &gamma, // hidden_size + c10::optional &beta_, // hidden_size + c10::optional &rowscale_, // BxS + c10::optional &colscale_, // hidden_size + c10::optional &x0_subset_, // BxS + c10::optional &z_subset_, // BxS + const float dropout_p, + const float epsilon, + const float rowscale_const, + const int64_t z_numrows, + c10::optional gen_, + bool residual_in_fp32=false, + bool is_rms_norm=false +) { + auto itype = x0.scalar_type(); + auto rtype = residual_.has_value() + ? residual_.value().scalar_type() + : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type()); + auto wtype = gamma.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + TORCH_CHECK(x0.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + + TORCH_CHECK(x0.is_contiguous()); + // c10::IntArrayRef does not own the storage, so we need to construct a vector. + // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because + // blah is then deallocated. + std::vector sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)}; + auto sizes = c10::IntArrayRef(sizes_vec); + TORCH_CHECK(x0.dim() == 2); + TORCH_CHECK(sizes.size() == 2); + + const int rows = sizes[0]; + const int cols = sizes[1]; + auto hidden_size = gamma.numel(); + TORCH_CHECK(hidden_size == cols); + + if (beta_.has_value()) { + auto beta = beta_.value(); + TORCH_CHECK(beta.dtype() == wtype); + TORCH_CHECK(beta.is_cuda()); + TORCH_CHECK(beta.is_contiguous()); + TORCH_CHECK(beta.sizes() == gamma.sizes()); + } + + if (residual_.has_value()) { + auto residual = residual_.value(); + TORCH_CHECK(residual.is_cuda()); + TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.sizes() == sizes); + } + + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()); + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(rowscale.dtype() == itype); + } + + if (colscale_.has_value()) { + auto colscale = colscale_.value(); + TORCH_CHECK(colscale.is_cuda()); + TORCH_CHECK(colscale.is_contiguous()); + TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); + TORCH_CHECK(colscale.dtype() == wtype); + } + + if (x0_subset_.has_value()) { + auto x0_subset = x0_subset_.value(); + TORCH_CHECK(x0_subset.is_cuda()); + TORCH_CHECK(x0_subset.is_contiguous()); + TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(x0_subset.dtype() == torch::kInt32); + + TORCH_CHECK(z_subset_.has_value()); + auto z_subset = z_subset_.value(); + TORCH_CHECK(z_subset.is_cuda()); + TORCH_CHECK(z_subset.is_contiguous()); + TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(z_subset.dtype() == torch::kInt32); + } + + TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); + TORCH_CHECK(epsilon >= 0.f); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x0.get_device()}; + + auto opts = x0.options(); + + bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype); + at::Tensor x; + if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } + at::Tensor dmask; + if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), 
opts.dtype(mtype)); }; + auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype)); + + auto mu = torch::empty({ rows }, opts.dtype(ctype)); + auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); + + layer_norm::LaunchParams launch_params; + + launch_params.props = at::cuda::getCurrentDeviceProperties(); + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; + launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; + launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); + // Request the kernel launcher. + auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); + + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x0 = x0.data_ptr(); + params.x = save_x ? x.data_ptr() : nullptr; + params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr; + params.z = z.data_ptr(); + params.epsilon = epsilon; + params.dropout_scale = 1.f / (1.f - dropout_p); + params.inverse_cols = 1.f / float(params.cols); + params.rowscale_const = rowscale_const; + params.is_rms_norm = is_rms_norm; + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + at::Tensor workspace, barrier; + + if (dropout_p > 0.f) { + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = launch_params.elts_per_thread; + + // See Note [Acquire lock when using random generators] + { + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + } + + if( launch_params.barrier_size > 0 ) { + auto options = x0.options(); + barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + // Launch the kernel. + launcher(launch_params, false); + + return { z, x, dmask, mu, rsigma }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size + c10::optional &dx_, // BxSxhidden_size + const at::Tensor &x, // BxSxhidden_size + c10::optional &x0_, // BxSxhidden_size + c10::optional &dmask_, // BxSxhidden_size + const at::Tensor &mu, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! 
+ const at::Tensor &gamma, // hidden_size + c10::optional &rowscale_, // BxS + c10::optional &colscale_, // hidden_size + c10::optional &x0_subset_, // BxS + c10::optional &z_subset_, // BxS + const float dropout_p, + const float rowscale_const, + const int64_t x0_numrows, + const bool has_residual, + bool is_rms_norm=false +) { + + auto itype = dz.scalar_type(); + auto rtype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } + + TORCH_CHECK(dz.dtype() == otype); + TORCH_CHECK(mu.dtype() == ctype); + TORCH_CHECK(rsigma.dtype() == ctype); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(mu.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + + TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(dz.is_contiguous()); + + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + auto rows = sizes[0]; + auto cols = sizes[1]; + TORCH_CHECK(dz.dim() == 2); + TORCH_CHECK(dz.size(1) == cols); + auto hidden_size = gamma.numel(); + TORCH_CHECK(hidden_size == cols); + + // c10::IntArrayRef does not own the storage, so we need to construct a vector. + // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because + // blah is then deallocated. + std::vector x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols}; + auto x0_sizes = c10::IntArrayRef(x0_sizes_vec); + + if (dx_.has_value()) { + auto dx = dx_.value(); + TORCH_CHECK(dx.dtype() == rtype); + TORCH_CHECK(dx.is_cuda()); + TORCH_CHECK(dx.is_contiguous()); + TORCH_CHECK(dx.sizes() == sizes); + } + + if (dmask_.has_value()) { + auto dmask = dmask_.value(); + TORCH_CHECK(dmask.dtype() == mtype); + TORCH_CHECK(dmask.is_cuda()); + TORCH_CHECK(dmask.is_contiguous()); + TORCH_CHECK(dmask.sizes() == x0_sizes); + } + + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()); + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(rowscale.dtype() == itype); + } + + if (colscale_.has_value()) { + auto colscale = colscale_.value(); + TORCH_CHECK(colscale.is_cuda()); + TORCH_CHECK(colscale.is_contiguous()); + TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); + TORCH_CHECK(colscale.dtype() == wtype); + + TORCH_CHECK(x0_.has_value()); + auto x0 = x0_.value(); + TORCH_CHECK(x0.is_cuda()); + TORCH_CHECK(x0.is_contiguous()); + TORCH_CHECK(x0.sizes() == x0_sizes); + TORCH_CHECK(x0.dtype() == itype); + } + + if (x0_subset_.has_value()) { + auto x0_subset = x0_subset_.value(); + TORCH_CHECK(x0_subset.is_cuda()); + TORCH_CHECK(x0_subset.is_contiguous()); + TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(x0_subset.dtype() == torch::kInt32); + + TORCH_CHECK(z_subset_.has_value()); + auto z_subset = z_subset_.value(); + TORCH_CHECK(z_subset.is_cuda()); + TORCH_CHECK(z_subset.is_contiguous()); + TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(z_subset.dtype() == torch::kInt32); + } + + TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); + + TORCH_CHECK(mu.numel() == rows); + TORCH_CHECK(mu.sizes() == rsigma.sizes()); + + TORCH_CHECK(gamma.numel() == cols); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)dz.get_device()}; + + auto opts = x.options(); + + auto dx0 = 
torch::empty(x0_sizes, opts.dtype(itype)); + at::Tensor dresidual; + if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); } + auto dgamma = torch::empty_like(gamma); + auto dbeta = torch::empty_like(gamma); + at::Tensor dcolscale; + if (colscale_.has_value()) { + dcolscale = torch::empty_like(colscale_.value()); + } + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; + launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; + launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); + auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); + + launcher(launch_params, true); + + auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + at::Tensor dcolscale_part; + if (colscale_.has_value()) { + dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + } + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr; + params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.dz = dz.data_ptr(); + params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr; + params.dx0 = dx0.data_ptr(); + params.dbeta = dbeta.data_ptr(); + params.dgamma = dgamma.data_ptr(); + params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr; + params.dbeta_part = dbeta_part.data_ptr(); + params.dgamma_part = dgamma_part.data_ptr(); + params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr; + params.dropout_scale = 1.f / (1.f - dropout_p); + params.inverse_cols = 1.f / float(params.cols); + params.rowscale_const = rowscale_const; + params.is_rms_norm = is_rms_norm; + + if( launch_params.barrier_size > 0 ) { + // TODO Any way to avoid this? 
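+        // Presumably hard to avoid as written: the inter-CTA sync counters must start
+        // at zero, so a freshly zeroed barrier buffer is allocated on every call.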
+ barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + launcher(launch_params, false); + + std::vector result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part }; + if (colscale_.has_value()) { + result.push_back(dcolscale); + result.push_back(dcolscale_part); + } + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_parallel_residual_fwd( + const at::Tensor &x0, // Input: BxSxhidden_size + c10::optional &x1_, // Input: BxSxhidden_size + c10::optional &residual_, // Residual: BxSxhidden_size + const at::Tensor &gamma0, // hidden_size + c10::optional &beta0_, // hidden_size + c10::optional &gamma1_, // hidden_size + c10::optional &beta1_, // hidden_size + const float dropout_p, + const float epsilon, + c10::optional gen_, + bool residual_in_fp32=false, + bool is_rms_norm=false +) { + auto itype = x0.scalar_type(); + auto rtype = residual_.has_value() + ? residual_.value().scalar_type() + : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type()); + auto wtype = gamma0.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + TORCH_CHECK(x0.is_cuda()); + TORCH_CHECK(gamma0.is_cuda()); + + TORCH_CHECK(x0.is_contiguous()); + const auto sizes = x0.sizes(); + TORCH_CHECK(x0.dim() == 2); + + const int rows = sizes[0]; + const int cols = sizes[1]; + auto hidden_size = gamma0.numel(); + TORCH_CHECK(hidden_size == cols); + + if (x1_.has_value()) { + auto x1 = x1_.value(); + TORCH_CHECK(x1.is_cuda()); + TORCH_CHECK(x1.is_contiguous()); + TORCH_CHECK(x1.sizes() == sizes); + } + + if (residual_.has_value()) { + auto residual = residual_.value(); + TORCH_CHECK(residual.is_cuda()); + TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.sizes() == sizes); + } + + if (beta0_.has_value()) { + auto beta0 = beta0_.value(); + TORCH_CHECK(beta0.dtype() == wtype); + TORCH_CHECK(beta0.is_cuda()); + TORCH_CHECK(beta0.is_contiguous()); + TORCH_CHECK(beta0.sizes() == gamma0.sizes()); + } + + if (gamma1_.has_value()) { + auto gamma1 = gamma1_.value(); + TORCH_CHECK(gamma1.dtype() == wtype); + TORCH_CHECK(gamma1.is_cuda()); + TORCH_CHECK(gamma1.is_contiguous()); + TORCH_CHECK(gamma1.sizes() == gamma0.sizes()); + } + + if (beta1_.has_value()) { + auto beta1 = beta1_.value(); + TORCH_CHECK(beta1.dtype() == wtype); + TORCH_CHECK(beta1.is_cuda()); + TORCH_CHECK(beta1.is_contiguous()); + TORCH_CHECK(beta1.sizes() == gamma0.sizes()); + } + + TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); + TORCH_CHECK(epsilon >= 0.f); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x0.get_device()}; + + auto opts = x0.options(); + + bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype); + at::Tensor x; + if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } + at::Tensor dmask0, dmask1; + if (dropout_p > 0.f) { + dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype)); + if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); } + }; + auto z0 = torch::empty(sizes, opts.dtype(otype)); + at::Tensor z1; + if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); } + + auto mu = 
torch::empty({ rows }, opts.dtype(ctype)); + auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); + + layer_norm::LaunchParams launch_params; + + launch_params.props = at::cuda::getCurrentDeviceProperties(); + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); + // Request the kernel launcher. + auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); + + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x0 = x0.data_ptr(); + params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr; + params.x = save_x ? x.data_ptr() : nullptr; + params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr; + params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma0.data_ptr(); + params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr; + params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr; + params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr; + params.z = z0.data_ptr(); + params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr; + params.epsilon = epsilon; + params.dropout_scale = 1.f / (1.f - dropout_p); + params.inverse_cols = 1.f / float(params.cols); + params.is_rms_norm = is_rms_norm; + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + at::Tensor workspace, barrier; + + if (dropout_p > 0.f) { + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = 2 * launch_params.elts_per_thread; + + // See Note [Acquire lock when using random generators] + { + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + } + + if( launch_params.barrier_size > 0 ) { + auto options = x0.options(); + barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + // Launch the kernel. + launcher(launch_params, false); + + return { z0, z1, x, dmask0, dmask1, mu, rsigma }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_parallel_residual_bwd( + const at::Tensor &dz0, // BxSxhidden_size + c10::optional &dz1_, // BxSxhidden_size + c10::optional &dx_, // BxSxhidden_size + const at::Tensor &x, // BxSxhidden_size + c10::optional &dmask0_, // BxSxhidden_size + c10::optional &dmask1_, // BxSxhidden_size + const at::Tensor &mu, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! 
+ const at::Tensor &gamma0, // hidden_size + c10::optional &gamma1_, // hidden_size + const float dropout_p, + const bool has_x1, + const bool has_residual, + bool is_rms_norm=false +) { + + auto itype = dz0.scalar_type(); + auto rtype = x.scalar_type(); + auto wtype = gamma0.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); } + + TORCH_CHECK(dz0.dtype() == otype); + TORCH_CHECK(dz0.dtype() == otype); + TORCH_CHECK(mu.dtype() == ctype); + TORCH_CHECK(rsigma.dtype() == ctype); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(dz0.is_cuda()); + TORCH_CHECK(mu.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma0.is_cuda()); + + TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(dz0.is_contiguous()); + + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + auto rows = sizes[0]; + auto cols = sizes[1]; + TORCH_CHECK(dz0.dim() == 2); + TORCH_CHECK(dz0.size(1) == cols); + auto hidden_size = gamma0.numel(); + TORCH_CHECK(hidden_size == cols); + + if (dz1_.has_value()) { + auto dz1 = dz1_.value(); + TORCH_CHECK(dz1.dtype() == otype); + TORCH_CHECK(dz1.is_cuda()); + TORCH_CHECK(dz1.is_contiguous()); + TORCH_CHECK(dz1.sizes() == sizes); + + TORCH_CHECK(gamma1_.has_value()); + auto gamma1 = gamma1_.value(); + TORCH_CHECK(gamma1.dtype() == wtype); + TORCH_CHECK(gamma1.is_cuda()); + TORCH_CHECK(gamma1.is_contiguous()); + TORCH_CHECK(gamma1.sizes() == gamma0.sizes()); + } + + if (dx_.has_value()) { + auto dx = dx_.value(); + TORCH_CHECK(dx.dtype() == rtype); + TORCH_CHECK(dx.is_cuda()); + TORCH_CHECK(dx.is_contiguous()); + TORCH_CHECK(dx.sizes() == sizes); + } + + if (dmask0_.has_value()) { + auto dmask0 = dmask0_.value(); + TORCH_CHECK(dmask0.dtype() == mtype); + TORCH_CHECK(dmask0.is_cuda()); + TORCH_CHECK(dmask0.is_contiguous()); + TORCH_CHECK(dmask0.sizes() == sizes); + + if (has_x1) { + TORCH_CHECK(dmask1_.has_value()); + auto dmask1 = dmask1_.value(); + TORCH_CHECK(dmask1.dtype() == mtype); + TORCH_CHECK(dmask1.is_cuda()); + TORCH_CHECK(dmask1.is_contiguous()); + TORCH_CHECK(dmask1.sizes() == sizes); + } + } + + TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); + + TORCH_CHECK(mu.numel() == rows); + TORCH_CHECK(mu.sizes() == rsigma.sizes()); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)dz0.get_device()}; + + auto opts = x.options(); + + auto dx0 = torch::empty(sizes, opts.dtype(itype)); + at::Tensor dx1; + if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); } + at::Tensor dresidual; + if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); } + auto dgamma0 = torch::empty_like(gamma0); + auto dbeta0 = torch::empty_like(gamma0); + at::Tensor dgamma1, dbeta1; + if (gamma1_.has_value()) { + dgamma1 = torch::empty_like(gamma0); + dbeta1 = torch::empty_like(gamma0); + } + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr; + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 
512 : 1024); + auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); + + launcher(launch_params, true); + + auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + at::Tensor dgamma1_part, dbeta1_part; + if (gamma1_.has_value()) { + dgamma1_part = torch::zeros_like(dgamma0_part); + dbeta1_part = torch::zeros_like(dbeta0_part); + } + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr; + params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma0.data_ptr(); + params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr; + params.dz = dz0.data_ptr(); + params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr; + params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr; + params.dx0 = dx0.data_ptr(); + params.dx1 = has_x1 ? dx1.data_ptr() : nullptr; + params.dbeta = dbeta0.data_ptr(); + params.dgamma = dgamma0.data_ptr(); + params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr; + params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr; + params.dbeta_part = dbeta0_part.data_ptr(); + params.dgamma_part = dgamma0_part.data_ptr(); + params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr; + params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr; + params.dropout_scale = 1.f / (1.f - dropout_p); + params.inverse_cols = 1.f / float(params.cols); + params.is_rms_norm = is_rms_norm; + + if( launch_params.barrier_size > 0 ) { + // TODO Any way to avoid this? 
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr();
+    }
+
+    launcher(launch_params, false);
+
+    std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "CUDA DropoutAddLayerNorm";
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
+          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
+          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
+          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
+          py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
+          py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
+          py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
+          py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
+          py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+}
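For reference, a minimal sketch of calling the parallel-residual bindings from Python, mirroring the argument order registered above. The output order shown (`z0, z1, x, dmask0, dmask1, mu, rsigma`) and the meaning of `x` are assumptions of this sketch, not guaranteed by the binding itself:

```python
import torch
import dropout_layer_norm  # the compiled extension

rows, d = 256, 1024
x0 = torch.randn(rows, d, device="cuda", dtype=torch.float16)        # e.g. attention branch
x1 = torch.randn_like(x0)                                            # e.g. MLP branch
residual = torch.randn(rows, d, device="cuda", dtype=torch.float32)  # residual kept in fp32
gamma0 = torch.ones(d, device="cuda", dtype=torch.float16)
beta0 = torch.zeros_like(gamma0)
gamma1 = torch.ones_like(gamma0)
beta1 = torch.zeros_like(gamma0)

# Assumed output order: z0 = LN0(x), z1 = LN1(x),
# x = dropout(x0) + dropout(x1) + residual, the two dropout masks, and mu / rsigma.
z0, z1, x, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
    x0, x1, residual, gamma0, beta0, gamma1, beta1,
    0.1,    # dropout_p
    1e-5,   # epsilon
    None,   # gen_: use the default CUDA RNG
    True,   # residual_in_fp32
    False,  # is_rms_norm
)
```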
diff --git a/ln_bwd_1024.cu b/ln_bwd_1024.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7101f6450fcdb8baa4ff4e79379d913048696b6 --- /dev/null +++ b/ln_bwd_1024.cu @@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
diff --git a/ln_bwd_1280.cu b/ln_bwd_1280.cu new file mode 100644 index 0000000000000000000000000000000000000000..a80a5762a178bd1fd1cd2ef4d0fb2010c1eea22e --- /dev/null +++ b/ln_bwd_1280.cu @@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
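Each per-size file registers the same ten dtype mixes. Reading the macro columns as (WTYPE, ITYPE, RTYPE, OTYPE) with fp32 compute throughout, the supported combinations are, tabulated for reference:

```python
# (weight, input, residual, output) dtypes registered for every hidden size;
# compute is always fp32. Transcribed from the REGISTER_* tables in this repo.
SUPPORTED_COMBOS = [
    ("fp32", "fp32", "fp32", "fp32"),
    ("fp16", "fp32", "fp32", "fp32"),
    ("fp32", "fp16", "fp32", "fp16"),
    ("fp16", "fp16", "fp32", "fp16"),
    ("fp32", "fp16", "fp16", "fp16"),
    ("fp32", "bf16", "fp32", "bf16"),
    ("bf16", "bf16", "fp32", "bf16"),
    ("fp32", "bf16", "bf16", "bf16"),
    ("fp16", "fp16", "fp16", "fp16"),
    ("bf16", "bf16", "bf16", "bf16"),
]
```

The rows with an fp32 residual but fp16/bf16 inputs and outputs are what the `residual_in_fp32` option selects.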
diff --git a/ln_bwd_1536.cu b/ln_bwd_1536.cu new file mode 100644 index 0000000000000000000000000000000000000000..0c25c088494d52f3b68251235d29c23a46ffc430 --- /dev/null +++ b/ln_bwd_1536.cu @@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
diff --git a/ln_bwd_2048.cu b/ln_bwd_2048.cu new file mode 100644 index 0000000000000000000000000000000000000000..06c0e608a3e48ec7fad2081bc6ff82425ea1c56a --- /dev/null +++ b/ln_bwd_2048.cu @@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
diff --git a/ln_bwd_256.cu b/ln_bwd_256.cu new file mode 100644 index 0000000000000000000000000000000000000000..20945432b8e97be21d80ada73aa0b3e709733a5b --- /dev/null +++ b/ln_bwd_256.cu @@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_bwd_2560.cu b/ln_bwd_2560.cu new file mode 100644 index 0000000000000000000000000000000000000000..309184c37b93e1f90bc1020a47973dae84f0f0c8 --- /dev/null +++ b/ln_bwd_2560.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/ln_bwd_3072.cu b/ln_bwd_3072.cu new file mode 100644 index 0000000000000000000000000000000000000000..e156b11cd92f450a6ce8e0c432487bd36d6f9847 --- /dev/null +++ b/ln_bwd_3072.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/ln_bwd_4096.cu b/ln_bwd_4096.cu new file mode 100644 index 0000000000000000000000000000000000000000..b715b0efe48c4111ae4301365018d19f537c7a81 --- /dev/null +++ b/ln_bwd_4096.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/ln_bwd_512.cu b/ln_bwd_512.cu new file mode 100644 index 0000000000000000000000000000000000000000..2b472118f0a0025917edc4c706492ca5dc8fa205 --- /dev/null +++ b/ln_bwd_512.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_bwd_5120.cu b/ln_bwd_5120.cu new file mode 100644 index 0000000000000000000000000000000000000000..38f3fbd406db8989f4a9806e64075bf52444c529 --- /dev/null +++ b/ln_bwd_5120.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/ln_bwd_6144.cu b/ln_bwd_6144.cu new file mode 100644 index 0000000000000000000000000000000000000000..469ed4b6c7691c581bbd1db5b8587de860afcb16 --- /dev/null +++ b/ln_bwd_6144.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/ln_bwd_7168.cu b/ln_bwd_7168.cu new file mode 100644 index 0000000000000000000000000000000000000000..549eab11aa3c770bea97bda727495f3e141ec24b --- /dev/null +++ b/ln_bwd_7168.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/ln_bwd_768.cu b/ln_bwd_768.cu new file mode 100644 index 0000000000000000000000000000000000000000..5db64d3d7b184f6ffb01ae0e1a26e0acec3bbe3d --- /dev/null +++ b/ln_bwd_768.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_bwd_8192.cu b/ln_bwd_8192.cu new file mode 100644 index 0000000000000000000000000000000000000000..e6514e613fe9cbf444ad4919a5acf9579b216c9e --- /dev/null +++ b/ln_bwd_8192.cu @@ -0,0 +1,15 @@ +#include "ln_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/ln_bwd_kernels.cuh b/ln_bwd_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c7261d218442acbcf60b61ce2e8803556193d8cd --- /dev/null +++ b/ln_bwd_kernels.cuh @@ -0,0 +1,534 @@ +#pragma once + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_bwd_kernel(layer_norm::BwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename 
Reducer::Type; + + extern __shared__ char smem_[]; + + const bool has_residual = params.dresidual != nullptr; + const bool prenorm = params.dx != nullptr; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); + + Cvec dzy_sum[LDGS]; + Cvec dz_sum[LDGS]; + Cvec dcolscale_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + memset(dz_sum, 0, sizeof(dz_sum)); + if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); } + + compute_t * smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + const index_t num_valid_ldgs = + ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; + + Wvec gamma[LDGS]; + Wvec colscale[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma[it].load_from(params.gamma, idx); + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the + // last blocks with syncthreads! + // grid stride over rows + #pragma unroll 1 + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t rs_r = static_cast(params.rs)[row]; + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_z = !Has_subset ? row + 1 : z_subset[row]; + const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; + const bool load_dz = !Has_subset || row_z > 0; + const bool save_dx0 = !Has_subset || row_x0 > 0; + Mvec dmask[LDGS]; + Rvec dx[LDGS]; + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + compute_t mdy_local = 0.f; + compute_t mdyy_local = 0.f; + // If dz is not loaded, then dy should be 0 and we don't care about the value of y. + if (load_dz) { + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Rvec x; + Ovec dz; + dz.load_from(params.dz, !Has_subset ? idx_x : idx_z); + if (prenorm) { dx[it].load_from(params.dx, idx_x); } + x.load_from(params.x, idx_x); + if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? 
idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_z += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_tmp = x.data.elt[jt]; + compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f)); + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]); + compute_t dz_tmp = dz.data.elt[jt]; + + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + dz_sum[it].data.elt[jt] += dz_tmp; + } + } + } + } else { + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + if (prenorm) { dx[it].load_from(params.dx, idx_x); } + if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; + } + } + } + + reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); + mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; + mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; + + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec dx0; + Rvec dresidual; + Ivec x0; + if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t dx_tmp_res; + if (load_dz) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); + dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; + } else { + dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f; + } + if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } + if (save_dx0) { + compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; + if (Is_dropout) { + dx0_tmp_res *= params.dropout_scale; + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; + } else { + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; + } + } else { + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); + dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); + } else { + dx0.data.elt[jt] = dx0_tmp_res; + } + } + } + } + if (has_residual) { dresidual.store_to(params.dresidual, idx_x); } + if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? 
idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; + } + } + + } // end: grid stride loop + + if( WARPS_M == 1 ) { + idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + dz_sum[it].store_to(params.dbeta_part, idx); + dzy_sum[it].store_to(params.dgamma_part, idx); + if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); } + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. + enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz_sum[NUM_RES]; + memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + + compute_t cta_dcolscale_sum[NUM_RES]; + if (Has_colscale) { + __syncthreads(); + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dcolscale_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + } + + const index_t num_valid_writes + = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; + compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; + compute_t *dcolscale_part = Has_colscale ? 
static_cast(params.dcolscale_part) + bidm * params.cols + tidx : nullptr; + for( int jt = 0; jt < NUM_RES; jt++ ) { + if (Is_even_cols || (jt < num_valid_writes)) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; + *dbeta_part = cta_dz_sum[jt]; + dbeta_part += Ktraits::THREADS_PER_CTA; + if (Has_colscale) { + *dcolscale_part = cta_dcolscale_sum[jt]; + dcolscale_part += Ktraits::THREADS_PER_CTA; + } + } + } + + } +} + +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) +void ln_bwd_finalize_kernel(BwdParams params) +{ + + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { + // Each thread sums over NUM_ELT columns. + Vec dbeta_local, dgamma_local, dcolscale_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + memset(&dbeta_local, 0, sizeof(dbeta_local)); + if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } + if (Is_even_cols || col < params.cols) { + for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { + index_t idx = row * params.cols + col; + + Vec dbeta_part, dgamma_part, dcolscale_part; + dbeta_part.load_from(params.dbeta_part, idx); + dgamma_part.load_from(params.dgamma_part, idx); + if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); } + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; } + } + } + } + void * smem_gamma = smem_; + void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; + + dgamma_local.store_to(smem_gamma, write_idx); + dbeta_local.store_to(smem_beta, write_idx); + if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); } + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta and smem_gamma + void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; + + + // More than one iter iff ROWS_PER_CTA < 32. 
+        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
+            const int read_row = lane;
+            const int read_col = w ^ read_row;
+            const int read_idx = read_row * THREADS_PER_WARP + read_col;
+
+            memset(&dbeta_local, 0, sizeof(dbeta_local));
+            memset(&dgamma_local, 0, sizeof(dgamma_local));
+            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
+
+            // Load beta and gamma transposed
+            if(read_row < Kernel_traits::ROWS_PER_CTA){
+                dbeta_local.load_from(smem_beta, read_idx);
+                dgamma_local.load_from(smem_gamma, read_idx);
+                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
+            }
+
+            // Call reducer on the loaded value(s) and convert.
+            #pragma unroll
+            for( int it = 0; it < NUM_ELT; it++ ) {
+                compute_t b_i = dbeta_local.data.elt[it];
+                compute_t g_i = dgamma_local.data.elt[it];
+                b_i = reducer.allreduce(b_i, sum);
+                g_i = reducer.allreduce(g_i, sum);
+
+                dgamma_local.data.elt[it] = g_i;
+                dbeta_local.data.elt[it] = b_i;
+                if (Has_colscale) {
+                    compute_t cs_i = dcolscale_local.data.elt[it];
+                    cs_i = reducer.allreduce(cs_i, sum);
+                    dcolscale_local.data.elt[it] = cs_i;
+                }
+            }
+
+            // Leader stores the result at the current column.
+            if(lane == 0){
+                dgamma_local.store_to(smem_gamma_out, w);
+                dbeta_local.store_to(smem_beta_out, w);
+                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
+            }
+
+        }
+
+        // All writes done.
+        __syncthreads();
+
+        // Pack and store: 2-wide stores with half the threads.
+        if (Is_even_cols || col_out * 2 < params.cols) {
+            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
+
+                using src_t = typename TypeToVec2<compute_t>::Type;
+                using dst_t = typename TypeToVec2<weight_t>::Type;
+                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
+                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
+
+                dgamma_vec2.load_from(smem_gamma_out, lane);
+                dbeta_vec2.load_from(smem_beta_out, lane);
+                if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
+                #pragma unroll
+                for( int it = 0; it < NUM_ELT; it++ ) {
+                    dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
+                    dbeta_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
+                    if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dcolscale_vec2.data.elt[it]); }
+                }
+                dgamma_out2.store_to(params.dgamma, col_out);
+                dbeta_out2.store_to(params.dbeta, col_out);
+                if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
+            }
+        }
+    }
+}
+} // namespace layer_norm
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE,
+    int CTAS_PER_ROW,
+    int WARPS_M,
+    int WARPS_N,
+    int BYTES_PER_LDG_MAIN,
+    int BYTES_PER_LDG_FINAL
+>
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
+
+    using Kernel_traits = Kernel_traits<weight_t, input_t, residual_t, output_t, compute_t, index_t,
+                                        HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
+    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
+    bool has_colscale = launch_params.params.colscale != nullptr;
+    bool has_subset = launch_params.params.x0_subset != nullptr;
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
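The `configure_params` branch below sizes a persistent grid from occupancy rather than from the row count (rows are consumed by the kernel's grid-stride loop), and requests an inter-CTA barrier plus scratch space whenever one row spans several CTAs. The same arithmetic, restated in illustrative Python:

```python
# Mirrors the sizing done in the configure_params branch; illustrative only.
def configure(sm_count, ctas_per_sm, ctas_per_row, warps_m, reduce_t_bytes):
    ctas_per_col = sm_count * ctas_per_sm // ctas_per_row
    barrier_size = 0
    workspace_bytes = 0
    if ctas_per_row > 1:  # cross-CTA reductions need a barrier and a workspace
        barrier_size = 2 * ctas_per_col
        workspace_bytes = ctas_per_col * warps_m * ctas_per_row * reduce_t_bytes * 2
    return ctas_per_col, barrier_size, workspace_bytes

# e.g. a 108-SM A100 at 2 CTAs/SM with CTAS_PER_ROW == 1:
print(configure(108, 2, 1, 1, 8))  # -> (216, 0, 0)
```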
+                    if( configure_params ) {
+                        int ctas_per_sm;
+                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
+                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                        launch_params.barrier_size = 0;
+                        launch_params.workspace_bytes = 0;
+                        if(Kernel_traits::CTAS_PER_ROW > 1) {
+                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                          * Kernel_traits::WARPS_M
+                                                          * Kernel_traits::CTAS_PER_ROW
+                                                          * sizeof(typename Kernel_traits::reduce_t)
+                                                          * 2;
+                        }
+                        return;
+                    }
+
+                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                    }
+                    auto stream = launch_params.stream;
+                    auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                    } else {
+                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                        dim3 block(Kernel_traits::THREADS_PER_CTA);
+                        void *params_ = (void *)&launch_params.params;
+                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                    }
+
+                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, residual_t, output_t,
+                                                                               compute_t, index_t, HasColscaleConst, 32 * 32, BYTES_PER_LDG_FINAL>;
+
+                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
+                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                });
+            });
+        });
+    });
+}
diff --git a/ln_fwd_1024.cu b/ln_fwd_1024.cu new file mode 100644 index 0000000000000000000000000000000000000000..824d86e9fd05920d3e557b42356feec86c904f68 --- /dev/null +++ b/ln_fwd_1024.cu @@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
diff --git a/ln_fwd_1280.cu b/ln_fwd_1280.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ff58cbc2889a2c06c51df560d2b35ca4e079201 --- /dev/null +++ b/ln_fwd_1280.cu @@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_1536.cu b/ln_fwd_1536.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8e19d4dba97d91cd246e62ba80a2936ac05755c --- /dev/null +++ b/ln_fwd_1536.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_2048.cu b/ln_fwd_2048.cu new file mode 100644 index 0000000000000000000000000000000000000000..6f9794c1e77f91a333d64cc6e461560622b87e12 --- /dev/null +++ b/ln_fwd_2048.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_256.cu b/ln_fwd_256.cu new file mode 100644 index 0000000000000000000000000000000000000000..f3a541c6dbf20cd94bb56607bbb23e6a81059bdc --- /dev/null +++ b/ln_fwd_256.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_2560.cu b/ln_fwd_2560.cu new file mode 100644 index 0000000000000000000000000000000000000000..1650671e059ec358f8109c1d592694458e77d489 --- /dev/null +++ b/ln_fwd_2560.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_3072.cu b/ln_fwd_3072.cu new file mode 100644 index 0000000000000000000000000000000000000000..25bb8691dc9f6a95297301efbd91567a5c22d1c2 --- /dev/null +++ b/ln_fwd_3072.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_fwd_4096.cu b/ln_fwd_4096.cu new file mode 100644 index 0000000000000000000000000000000000000000..b2bffb5831bf1b6eb18cd1e2cd2c4636a06f5736 --- /dev/null +++ b/ln_fwd_4096.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_fwd_512.cu b/ln_fwd_512.cu new file mode 100644 index 0000000000000000000000000000000000000000..a08fe34c55d61eecdbc74caa41dfbec10b3a8126 --- /dev/null +++ b/ln_fwd_512.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_5120.cu b/ln_fwd_5120.cu new file mode 100644 index 0000000000000000000000000000000000000000..bebbd69f05b38a5e3c0dae5d248de467118ef8c5 --- /dev/null +++ b/ln_fwd_5120.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_fwd_6144.cu b/ln_fwd_6144.cu new file mode 100644 index 0000000000000000000000000000000000000000..4df01ead2f292e255221e6fb0b48e63941a22cab --- /dev/null +++ b/ln_fwd_6144.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/ln_fwd_7168.cu b/ln_fwd_7168.cu new file mode 100644 index 0000000000000000000000000000000000000000..8343666d10c2788cb2c19ba4f448eef2ccf2b956 --- /dev/null +++ b/ln_fwd_7168.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_fwd_768.cu b/ln_fwd_768.cu new file mode 100644 index 0000000000000000000000000000000000000000..06d5a3b09cdd4941764885f5107bbbfa6b264eef --- /dev/null +++ b/ln_fwd_768.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_fwd_8192.cu b/ln_fwd_8192.cu new file mode 100644 index 0000000000000000000000000000000000000000..bf7cb40252baf820c88dff1337c81dffd934087a --- /dev/null +++ b/ln_fwd_8192.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/ln_fwd_kernels.cuh b/ln_fwd_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f6bccb8c28a2b3d967dddc3d8b21e1888ed2e29c --- /dev/null +++ b/ln_fwd_kernels.cuh @@ -0,0 +1,272 @@ +#pragma once + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include // For at::cuda::philox::unpack +#include + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using residual_t = typename Ktraits::residual_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + const bool has_residual = params.residual != nullptr; + const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); + + // 
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu + curandStatePhilox4_32_10_t state; + if (Is_dropout) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; + curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); + } + + const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + Wvec colscale[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma[it].load_from(params.gamma, idx); + if (params.beta != nullptr) { + beta[it].load_from(params.beta, idx); + } else { + beta[it].zero_(); + } + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } + idx += VEC_COLS_PER_LDG; + } + } + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; + const int row_z = !Has_subset ? row + 1 : z_subset[row]; + const bool load_x0 = !Has_subset || row_x0 > 0; + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + compute_t xf[LDGS * NUM_ELTS]; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec x0; + Rvec residual; + Rvec x; + Mvec dmask; + if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } + if (has_residual) { residual.load_from(params.residual, idx_x); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use + // the more efficient curand_uniform4. + compute_t x_ij; + if (load_x0) { + mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; + if (Is_dropout) { dmask.data.elt[jt] = keep; } + compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; + x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } + x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; + } else { + x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f; + } + if (save_x) { x.data.elt[jt] = x_ij; } + xf[it * NUM_ELTS + jt] = x_ij; + } + if (save_x) { x.store_to(params.x, idx_x); } + if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += VEC_COLS_PER_LDG; + idx_x0 += VEC_COLS_PER_LDG; + } + } + + static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); + const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; + const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; + const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; + auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { + // Need to convert to int, otherwise the subtraction will wrap around. 
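Stripped of vectorization and the subset bookkeeping, the per-element update this inner loop performs is roughly the following (a scalar sketch; `keep_p` stands for `params.dropout_keep_p`, and `1/keep_p` for `params.dropout_scale`):

```cpp
// Scalar sketch of the fused forward update:
//   x = colscale * dropout(x0 * rowscale) + residual,
// which is then fed into the LayerNorm/RMSNorm statistics.
float fused_elt(float x0, float residual, float rowscale, float colscale,
                bool keep, float keep_p) {
    float v = x0 * rowscale;
    v = keep ? v / keep_p : 0.f;   // inverted dropout: survivors are scaled up by 1/keep_p
    v *= colscale;                 // per-column scale (identity when colscale is absent)
    return v + residual;           // residual add (skipped when residual is absent)
}
```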
+ const index_t valid_partial_vecs_in_warp = + std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), + int(THREADS_PER_WARP)); + return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; + }; + stats_t s = stats.template compute<Is_even_cols>( + xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS + ); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + const bool save_z = !Has_subset || row_z > 0; + if (save_z) { + index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ovec z; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); + compute_t g_ij = gamma[it].data.elt[jt]; + compute_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); + } + z.store_to(params.z, idx_z); + idx_z += VEC_COLS_PER_LDG; + } + } + } + + } +} + +} // namespace layer_norm + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG +> +void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){ + + using Kernel_traits = layer_norm::Kernel_traits<weight_t, input_t, residual_t, output_t, compute_t, index_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>; + bool has_colscale = launch_params.params.colscale != nullptr; + bool has_subset = launch_params.params.x0_subset != nullptr; + bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(has_subset, HasSubsetConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params); + } else { + dim3 
grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); + }); + }); + }); +} diff --git a/ln_kernel_traits.h b/ln_kernel_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..77de6bf9af60c9ae70427097db26cf4ed130b359 --- /dev/null +++ b/ln_kernel_traits.h @@ -0,0 +1,172 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t THREADS_PER_CTA_ +> +struct Kernel_traits_base { + + using weight_t = weight_t_; + using input_t = input_t_; + using residual_t = residual_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP = 32 }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + bool Has_colscale, + uint32_t THREADS_PER_CTA_, + uint32_t BYTES_PER_LDG_, + typename Base = Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, residual_t_, output_t_, compute_t_, index_t_, THREADS_PER_CTA_> +> +struct Kernel_traits_finalize : public Base { + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. + enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalesce the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2; + enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>; + + // Condition for the whole CTA to participate in syncthreads. 
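A note on the `rs` computed in `ln_fwd_kernel` above: the stats object returns the row mean mu and the sum of squared deviations m2, and the RMSNorm branch adds mu^2 back to recover the raw second moment, since

```latex
\frac{m_2}{N} + \mu^2
  = \frac{1}{N}\sum_i (x_i - \mu)^2 + \mu^2
  = \frac{1}{N}\sum_i x_i^2,
\qquad
r_s =
\begin{cases}
  1/\sqrt{\mathrm{Var}(x) + \varepsilon} & \text{LayerNorm}, \\
  1/\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \varepsilon} & \text{RMSNorm},
\end{cases}
```

which is exactly `rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu))`.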
+ static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template< + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, + uint32_t BYTES_PER_LDG_ = 16, + typename Base = Kernel_traits_base< + HIDDEN_SIZE_, + weight_t_, + input_t_, + residual_t_, + output_t_, + compute_t_, + index_t_, + WARPS_M_*WARPS_N_*THREADS_PER_WARP + > +> +struct Kernel_traits : public Base { + + using input_t = typename Base::input_t; + using residual_t = typename Base::residual_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + // using mask_t = unsigned char; + using mask_t = bool; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type; + using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = layer_norm::Vec<input_t, NUM_ELTS>; + using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>; + using Ovec = layer_norm::Vec<output_t, NUM_ELTS>; + using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>; + using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>; + using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in the input. + static_assert(sizeof(input_t) == sizeof(output_t)); + static_assert(sizeof(input_t) <= sizeof(residual_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. + enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/ln_parallel_bwd_1024.cu b/ln_parallel_bwd_1024.cu new file mode 100644 index 0000000000000000000000000000000000000000..6f4e77466c6c6d5a00275d54f4e68da062a5fc1a --- /dev/null +++ b/ln_parallel_bwd_1024.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register.
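To make the shared-memory terms above concrete, here is the budget the traits imply for one of the configurations registered in this diff (hidden size 1024 with WARPS_M = 4, WARPS_N = 1); a sketch assuming `compute_t` = fp32:

```cpp
#include <cstddef>

constexpr int ROWS_PER_CTA = 4;   // == WARPS_M
constexpr int COLS = 1024;        // == HIDDEN_SIZE
// SMEM_BYTES_WGRAD stages the per-row dgamma/dbeta partials for the backward pass.
constexpr size_t SMEM_BYTES_WGRAD = ROWS_PER_CTA * COLS * sizeof(float);
static_assert(SMEM_BYTES_WGRAD == 16 * 1024, "16 KiB of staging per CTA");
```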
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_parallel_bwd_1280.cu b/ln_parallel_bwd_1280.cu new file mode 100644 index 0000000000000000000000000000000000000000..2dba3bebf26e99b853e7ef4b9b56421cf483e0bd --- /dev/null +++ b/ln_parallel_bwd_1280.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_parallel_bwd_1536.cu b/ln_parallel_bwd_1536.cu new file mode 100644 index 0000000000000000000000000000000000000000..c2ac4b1b0998ca412dea02466f0d8fbe69f48216 --- /dev/null +++ b/ln_parallel_bwd_1536.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
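Compared to the forward macro, the backward signature carries one extra parameter, BYTES_PER_LDG_FINAL (4 bytes in these tables, i.e. one fp32 per load; judging by the name, the vector width used by the separate finalize kernel). The backward is split into two stages: the main kernel writes per-CTA fp32 partials `dgamma_part`/`dbeta_part`, and a finalize kernel reduces them across CTAs. A plain-loop sketch of the reference semantics (the real kernels vectorize and stage through shared memory):

```cpp
// Stage 1 (per CTA m, over the rows r it owns):
//   dgamma_part[m][c] += dz[r][c] * y[r][c];   dbeta_part[m][c] += dz[r][c];
// Stage 2 (finalize) reduces the partials over m:
void finalize_ref(const float* dgamma_part, const float* dbeta_part,
                  float* dgamma, float* dbeta, int ctas_per_col, int cols) {
    for (int c = 0; c < cols; ++c) {
        float dg = 0.f, db = 0.f;
        for (int m = 0; m < ctas_per_col; ++m) {
            dg += dgamma_part[m * cols + c];
            db += dbeta_part[m * cols + c];
        }
        dgamma[c] = dg;
        dbeta[c] = db;
    }
}
```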
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/ln_parallel_bwd_2048.cu b/ln_parallel_bwd_2048.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7f959e2fa785a4df3b6a32506f527e1723d83cc --- /dev/null +++ b/ln_parallel_bwd_2048.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_256.cu b/ln_parallel_bwd_256.cu new file mode 100644 index 0000000000000000000000000000000000000000..fa613cf45e1045d046cefc4afd55ded754bc20a4 --- /dev/null +++ b/ln_parallel_bwd_256.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_parallel_bwd_2560.cu b/ln_parallel_bwd_2560.cu new file mode 100644 index 0000000000000000000000000000000000000000..5f5707612df09149885d7883728672dc3a2b751f --- /dev/null +++ b/ln_parallel_bwd_2560.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/ln_parallel_bwd_3072.cu b/ln_parallel_bwd_3072.cu new file mode 100644 index 0000000000000000000000000000000000000000..8fdcb8ffb4d0f0e0fcae6aee930808bd0349ede5 --- /dev/null +++ b/ln_parallel_bwd_3072.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_4096.cu b/ln_parallel_bwd_4096.cu new file mode 100644 index 0000000000000000000000000000000000000000..8decfb085ac8ace1e3694a491bb66a83209027b8 --- /dev/null +++ b/ln_parallel_bwd_4096.cu @@ -0,0 +1,17 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +// Use 8 warps otherwise there's a lot of register spilling + +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_512.cu b/ln_parallel_bwd_512.cu new file mode 100644 index 0000000000000000000000000000000000000000..178453d3045bfefd95018320d357ea8662018782 --- /dev/null +++ b/ln_parallel_bwd_512.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_parallel_bwd_5120.cu b/ln_parallel_bwd_5120.cu new file mode 100644 index 0000000000000000000000000000000000000000..815521973da7266534c7e8b167fa0b8baa47fa2c --- /dev/null +++ b/ln_parallel_bwd_5120.cu @@ -0,0 +1,17 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +// Use 8 warps otherwise there's a lot of register spilling + +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_6144.cu b/ln_parallel_bwd_6144.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb8668d8a229d2ec24e5eac57db00f9d650615eb --- /dev/null +++ b/ln_parallel_bwd_6144.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_7168.cu b/ln_parallel_bwd_7168.cu new file mode 100644 index 0000000000000000000000000000000000000000..0c12dc476678ce7b24c5fcd0b9408eb686bd6825 --- /dev/null +++ b/ln_parallel_bwd_7168.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/ln_parallel_bwd_768.cu b/ln_parallel_bwd_768.cu new file mode 100644 index 0000000000000000000000000000000000000000..8beece8ab19cea2baefedd118f5d15c90a646526 --- /dev/null +++ b/ln_parallel_bwd_768.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/ln_parallel_bwd_8192.cu b/ln_parallel_bwd_8192.cu new file mode 100644 index 0000000000000000000000000000000000000000..5ad47c94fdff599dde62574d1c535c4bbacae551 --- /dev/null +++ b/ln_parallel_bwd_8192.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_bwd_kernels.cuh" + +// Create backward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/ln_parallel_fwd_1024.cu b/ln_parallel_fwd_1024.cu new file mode 100644 index 0000000000000000000000000000000000000000..3c64e169302eea0f94ff65641728c35689d7c4ba --- /dev/null +++ b/ln_parallel_fwd_1024.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. 
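The `ln_parallel_*` files cover the parallel-residual block layout used by GPT-J-style models: two branch inputs x0 and x1 are dropped out and added to the residual, and the resulting x is normalized once with gamma0/beta0 (and, when the two norms are untied, a second time with gamma1/beta1). A plain C++ sketch of one row of the forward, with dropout omitted for brevity (the untied case; names are for illustration only):

```cpp
#include <cmath>
#include <vector>

void parallel_residual_fwd_row(
    const std::vector<float>& x0, const std::vector<float>& x1, const std::vector<float>& residual,
    const std::vector<float>& gamma0, const std::vector<float>& beta0,
    const std::vector<float>& gamma1, const std::vector<float>& beta1,
    std::vector<float>& z0, std::vector<float>& z1, float epsilon) {
    const int n = static_cast<int>(x0.size());
    std::vector<float> x(n);
    float mu = 0.f;
    for (int i = 0; i < n; ++i) { x[i] = x0[i] + x1[i] + residual[i]; mu += x[i]; }
    mu /= n;
    float m2 = 0.f;
    for (int i = 0; i < n; ++i) { const float d = x[i] - mu; m2 += d * d; }
    const float rs = 1.f / std::sqrt(m2 / n + epsilon);
    for (int i = 0; i < n; ++i) {
        const float y = (x[i] - mu) * rs;    // both norms read the same x
        z0[i] = gamma0[i] * y + beta0[i];
        z1[i] = gamma1[i] * y + beta1[i];    // skipped when the norms are tied
    }
}
```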
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_1280.cu b/ln_parallel_fwd_1280.cu new file mode 100644 index 0000000000000000000000000000000000000000..9bbfce5bc6c5e0303d70552bb36cf380601dcd38 --- /dev/null +++ b/ln_parallel_fwd_1280.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_1536.cu b/ln_parallel_fwd_1536.cu new file mode 100644 index 0000000000000000000000000000000000000000..b57f5edce8eb7b6779475f6eadb8aabba299c802 --- /dev/null +++ b/ln_parallel_fwd_1536.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_2048.cu b/ln_parallel_fwd_2048.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fa322d96b4e11aacf5722985672e141f929299b --- /dev/null +++ b/ln_parallel_fwd_2048.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_256.cu b/ln_parallel_fwd_256.cu new file mode 100644 index 0000000000000000000000000000000000000000..27445a6bc50c98935c7a5093ee5ffdddf52e2494 --- /dev/null +++ b/ln_parallel_fwd_256.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); \ No newline at end of file diff --git a/ln_parallel_fwd_2560.cu b/ln_parallel_fwd_2560.cu new file mode 100644 index 0000000000000000000000000000000000000000..fdde470c267302adca3d63f2c6b736b67af7ee86 --- /dev/null +++ b/ln_parallel_fwd_2560.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_3072.cu b/ln_parallel_fwd_3072.cu new file mode 100644 index 0000000000000000000000000000000000000000..992f71037607066fb4e4d0f1624669f21c2f53b1 --- /dev/null +++ b/ln_parallel_fwd_3072.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_parallel_fwd_4096.cu b/ln_parallel_fwd_4096.cu new file mode 100644 index 0000000000000000000000000000000000000000..381837e60874e44aa5e0efccb8749b2ff41ac3fa --- /dev/null +++ b/ln_parallel_fwd_4096.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_parallel_fwd_512.cu b/ln_parallel_fwd_512.cu new file mode 100644 index 0000000000000000000000000000000000000000..4ba478b01fbdbc2ff5aab0a15fb698eba369f61a --- /dev/null +++ b/ln_parallel_fwd_512.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_5120.cu b/ln_parallel_fwd_5120.cu new file mode 100644 index 0000000000000000000000000000000000000000..7ada35228cb603ddd26b06e186989746a86926a8 --- /dev/null +++ b/ln_parallel_fwd_5120.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_parallel_fwd_6144.cu b/ln_parallel_fwd_6144.cu new file mode 100644 index 0000000000000000000000000000000000000000..6f531c881f7f53651c56e3afd1f0f53c580815ec --- /dev/null +++ b/ln_parallel_fwd_6144.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/ln_parallel_fwd_7168.cu b/ln_parallel_fwd_7168.cu new file mode 100644 index 0000000000000000000000000000000000000000..c99e752cd484a99e97f8bf7a92e433a817c54d64 --- /dev/null +++ b/ln_parallel_fwd_7168.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/ln_parallel_fwd_768.cu b/ln_parallel_fwd_768.cu new file mode 100644 index 0000000000000000000000000000000000000000..f33f519c7fb2934b3b5aabf36a2d9046c4b51ee3 --- /dev/null +++ b/ln_parallel_fwd_768.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register.
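The warp shapes in these tables follow a consistent pattern: small hidden sizes pack four rows into a CTA (WARPS_M = 4), while larger ones dedicate the whole CTA to a single row and widen it (WARPS_N = 4, then 8), with the half-precision 7168 entries dropping back to 4 warps. A descriptive (not normative) summary of the fp32 rows:

```cpp
struct WarpShape { int warps_m, warps_n; };

// Mirrors the fp32 rows of the registration tables above; per-dtype
// exceptions (e.g. the 16-bit 7168 entries) deviate from this.
constexpr WarpShape pick_shape(int hidden_size) {
    return hidden_size <= 2560 ? WarpShape{4, 1}
         : hidden_size <= 5120 ? WarpShape{1, 4}
                               : WarpShape{1, 8};
}
static_assert(pick_shape(768).warps_m == 4, "");
static_assert(pick_shape(8192).warps_n == 8, "");
```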
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/ln_parallel_fwd_8192.cu b/ln_parallel_fwd_8192.cu new file mode 100644 index 0000000000000000000000000000000000000000..360e6d4471062cd40bf245ecff22b579f56d4020 --- /dev/null +++ b/ln_parallel_fwd_8192.cu @@ -0,0 +1,15 @@ +#include "ln_parallel_residual_fwd_kernels.cuh" + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); +REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/ln_parallel_residual_bwd_kernels.cuh b/ln_parallel_residual_bwd_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..521495724400fde6eaecb27e255154a51d8ddbb0 --- /dev/null +++ b/ln_parallel_residual_bwd_kernels.cuh @@ -0,0 +1,540 @@ +#pragma once + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" +#include "ln_bwd_kernels.cuh" + +namespace layer_norm { + +template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols> +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename 
Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename Reducer::Type; + + extern __shared__ char smem_[]; + + const bool has_residual = params.dresidual != nullptr; + const bool has_x1 = params.dx1 != nullptr; + const bool prenorm = params.dx != nullptr; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dz0y_sum[LDGS]; + Cvec dz0_sum[LDGS]; + Cvec dz1y_sum[LDGS]; + Cvec dz1_sum[LDGS]; + + memset(dz0y_sum, 0, sizeof(dz0y_sum)); + memset(dz0_sum, 0, sizeof(dz0_sum)); + if (!Tied_norm) { + memset(dz1y_sum, 0, sizeof(dz1y_sum)); + memset(dz1_sum, 0, sizeof(dz1_sum)); + } + + compute_t * smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + const index_t num_valid_ldgs = + ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; + + Wvec gamma0[LDGS]; + Wvec gamma1[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma0[it].load_from(params.gamma, idx); + if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); } + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the + // last blocks with syncthreads! + // grid stride over rows + #pragma unroll 1 + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t rs_r = static_cast(params.rs)[row]; + Mvec dmask0[LDGS], dmask1[LDGS]; + Rvec dx[LDGS]; + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + compute_t mdy_local = 0.f; + compute_t mdyy_local = 0.f; + index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Rvec x; + Ovec dz0, dz1; + dz0.load_from(params.dz, idx); + if (!Tied_norm) { dz1.load_from(params.dz1, idx); } + if (prenorm) { dx[it].load_from(params.dx, idx); } + x.load_from(params.x, idx); + if (Is_dropout) { + dmask0[it].load_from(params.dmask, idx); + if (has_x1) { dmask1[it].load_from(params.dmask1, idx); } + } + idx += Ktraits::VEC_COLS_PER_LDG; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_tmp = x.data.elt[jt]; + compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? 
mu_r : 0.f)); + compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]); + if (!Tied_norm) { + dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]); + } + compute_t dz0_tmp = dz0.data.elt[jt]; + compute_t dz1_tmp; + if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; } + + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp; + dz0_sum[it].data.elt[jt] += dz0_tmp; + if (!Tied_norm) { + dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp; + dz1_sum[it].data.elt[jt] += dz1_tmp; + } + } + } + } + + reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); + mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; + mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; + + idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec dx0, dx1; + Rvec dresidual; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t dx_tmp_res; + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); + dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; + if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } + if (Is_dropout) { + dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; + if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; } + } else { + dx0.data.elt[jt] = dx_tmp_res; + if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; } + } + } + if (has_residual) { dresidual.store_to(params.dresidual, idx); } + dx0.store_to(params.dx0, idx); + if (has_x1) { dx1.store_to(params.dx1, idx); } + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + + } // end: grid stride loop + + if( WARPS_M == 1 ) { + idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + dz0_sum[it].store_to(params.dbeta_part, idx); + dz0y_sum[it].store_to(params.dgamma_part, idx); + if (!Tied_norm) { + dz1_sum[it].store_to(params.dbeta1_part, idx); + dz1y_sum[it].store_to(params.dgamma1_part, idx); + } + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. 
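The `dx` expression in the loop above is the standard LayerNorm input-gradient identity. With y_i = (x_i - mu) * rs and dy_i = gamma0_i * dz0_i (plus the gamma1_i * dz1_i term in the untied case), the allreduce supplies the two row means needed for

```latex
\frac{\partial L}{\partial x_i}
  = r_s \left( dy_i - \overline{dy\,y}\; y_i - \overline{dy} \right),
\qquad
\overline{dy} = \frac{1}{N}\sum_j dy_j,
\quad
\overline{dy\,y} = \frac{1}{N}\sum_j dy_j\, y_j,
```

with the mean-of-dy term dropped for RMSNorm (the `!params.is_rms_norm ? mdy_local : 0.f` branch); the result then flows through the dropout masks to dx0/dx1 and, unmasked, to dresidual.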
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz0_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz0_sum[NUM_RES]; + memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz0y_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz0y_sum[NUM_RES]; + memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + + compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES]; + if (!Tied_norm) { + __syncthreads(); + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz1_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz1y_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + } + + const index_t num_valid_writes + = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; + compute_t *dgamma0_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; + compute_t *dbeta0_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; + compute_t *dgamma1_part = !Tied_norm ? static_cast(params.dgamma1_part) + bidm * params.cols + tidx : nullptr; + compute_t *dbeta1_part = !Tied_norm ? 
static_cast(params.dbeta1_part) + bidm * params.cols + tidx : nullptr; + for( int jt = 0; jt < NUM_RES; jt++ ) { + if (Is_even_cols || (jt < num_valid_writes)) { + *dgamma0_part = cta_dz0y_sum[jt]; + dgamma0_part += Ktraits::THREADS_PER_CTA; + *dbeta0_part = cta_dz0_sum[jt]; + dbeta0_part += Ktraits::THREADS_PER_CTA; + if (!Tied_norm) { + *dgamma1_part = cta_dz1y_sum[jt]; + dgamma1_part += Ktraits::THREADS_PER_CTA; + *dbeta1_part = cta_dz1_sum[jt]; + dbeta1_part += Ktraits::THREADS_PER_CTA; + } + } + } + + } +} + +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) +void ln_parallel_residual_bwd_finalize_kernel(BwdParams params) +{ + + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + // Multiplying by 2 since we have both gamma0 and gamma1 + __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { + // Each thread sums over NUM_ELT columns. 
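+        // Descriptive note: dgamma_part/dbeta_part hold one partial row per CTA of
+        // the main backward kernel (params.ctas_per_col rows in total). Each thread
+        // first accumulates its NUM_ELT columns over those rows; the xor-swizzled
+        // shared-memory transpose below then lets a single warp finish the
+        // reduction with shuffles while avoiding bank conflicts.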
+ Vec dbeta0_local, dgamma0_local, dbeta1_local, dgamma1_local; + memset(&dgamma0_local, 0, sizeof(dgamma0_local)); + memset(&dbeta0_local, 0, sizeof(dbeta0_local)); + memset(&dgamma1_local, 0, sizeof(dgamma1_local)); + memset(&dbeta1_local, 0, sizeof(dbeta1_local)); + if (Is_even_cols || col < params.cols) { + for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { + index_t idx = row * params.cols + col; + + Vec dbeta0_part, dgamma0_part, dbeta1_part, dgamma1_part; + dbeta0_part.load_from(params.dbeta_part, idx); + dgamma0_part.load_from(params.dgamma_part, idx); + dbeta1_part.load_from(params.dbeta1_part, idx); + dgamma1_part.load_from(params.dgamma1_part, idx); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma0_local.data.elt[it] += dgamma0_part.data.elt[it]; + dbeta0_local.data.elt[it] += dbeta0_part.data.elt[it]; + dgamma1_local.data.elt[it] += dgamma1_part.data.elt[it]; + dbeta1_local.data.elt[it] += dbeta1_part.data.elt[it]; + } + } + } + void * smem_gamma0 = smem_; + void * smem_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; + + dgamma0_local.store_to(smem_gamma0, write_idx); + dbeta0_local.store_to(smem_beta0, write_idx); + dgamma1_local.store_to(smem_gamma1, write_idx); + dbeta1_local.store_to(smem_beta1, write_idx); + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta0 and smem_gamma0 + void * smem_gamma0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_beta0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + void * smem_gamma1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; + void * smem_beta1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT]; + + // More than one iter iff ROWS_PER_CTA < 32. + for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; + + memset(&dbeta0_local, 0, sizeof(dbeta0_local)); + memset(&dgamma0_local, 0, sizeof(dgamma0_local)); + memset(&dbeta1_local, 0, sizeof(dbeta1_local)); + memset(&dgamma1_local, 0, sizeof(dgamma1_local)); + + // Load beta and gamma transposed + if(read_row < Kernel_traits::ROWS_PER_CTA){ + dbeta0_local.load_from(smem_beta0, read_idx); + dgamma0_local.load_from(smem_gamma0, read_idx); + dbeta1_local.load_from(smem_beta1, read_idx); + dgamma1_local.load_from(smem_gamma1, read_idx); + } + + // Call reducer on the loaded value(s) and convert. + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + compute_t b0_i = dbeta0_local.data.elt[it]; + compute_t g0_i = dgamma0_local.data.elt[it]; + compute_t b1_i = dbeta1_local.data.elt[it]; + compute_t g1_i = dgamma1_local.data.elt[it]; + b0_i = reducer.allreduce(b0_i, sum); + g0_i = reducer.allreduce(g0_i, sum); + b1_i = reducer.allreduce(b1_i, sum); + g1_i = reducer.allreduce(g1_i, sum); + + dgamma0_local.data.elt[it] = g0_i; + dbeta0_local.data.elt[it] = b0_i; + dgamma1_local.data.elt[it] = g1_i; + dbeta1_local.data.elt[it] = b1_i; + } + + // Leader stores the result at the current column. 
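+            // After allreduce each lane holds the final fp32 column sums; lane 0
+            // stages them in the *_out buffers, and the last warp later packs
+            // pairs of columns into vectorized 2-wide stores of the weight type.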
+ if(lane == 0){ + dgamma0_local.store_to(smem_gamma0_out, w); + dbeta0_local.store_to(smem_beta0_out, w); + dgamma1_local.store_to(smem_gamma1_out, w); + dbeta1_local.store_to(smem_beta1_out, w); + } + + } + + // All writes done. + __syncthreads(); + + // Pack and store: 2-wide stores with half the threads. + if (Is_even_cols || col_out * 2 < params.cols) { + if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { + + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dbeta0_vec2, dgamma0_vec2, dbeta1_vec2, dgamma1_vec2; + Vec dbeta0_out2, dgamma0_out2, dbeta1_out2, dgamma1_out2; + + dgamma0_vec2.load_from(smem_gamma0_out, lane); + dbeta0_vec2.load_from(smem_beta0_out, lane); + dgamma1_vec2.load_from(smem_gamma1_out, lane); + dbeta1_vec2.load_from(smem_beta1_out, lane); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma0_out2.data.elt[it] = Converter::convert(dgamma0_vec2.data.elt[it]); + dbeta0_out2.data.elt[it] = Converter::convert(dbeta0_vec2.data.elt[it]); + dgamma1_out2.data.elt[it] = Converter::convert(dgamma1_vec2.data.elt[it]); + dbeta1_out2.data.elt[it] = Converter::convert(dbeta1_vec2.data.elt[it]); + } + dgamma0_out2.store_to(params.dgamma, col_out); + dbeta0_out2.store_to(params.dbeta, col_out); + dgamma1_out2.store_to(params.dgamma1, col_out); + dbeta1_out2.store_to(params.dbeta1, col_out); + } + } + } +} + +} // namespace layer_norm + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG_MAIN, + int BYTES_PER_LDG_FINAL +> +void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ + + using Kernel_traits = Kernel_traits; + bool is_dropout = launch_params.params.dropout_keep_p < 1.f; + bool tied_norm = launch_params.params.gamma1 == nullptr; + bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(tied_norm, TiedNormConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_parallel_residual_bwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::reduce_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, 
Kernel_traits::SMEM_BYTES, stream); + } + + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + + auto kernel_f = !TiedNormConst + ? &layer_norm::ln_parallel_residual_bwd_finalize_kernel + : &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); + + }); + }); + }); +} diff --git a/ln_parallel_residual_fwd_kernels.cuh b/ln_parallel_residual_fwd_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0e55cb4038b4dbe30d9eb47609df3afea4c4f5fb --- /dev/null +++ b/ln_parallel_residual_fwd_kernels.cuh @@ -0,0 +1,281 @@ +#pragma once + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include // For at::cuda::philox::unpack +#include + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_parallel_residual_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using residual_t = typename Ktraits::residual_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + const bool has_residual = params.residual != nullptr; + const bool has_x1 = params.x1 != nullptr; + const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same::value); + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu + curandStatePhilox4_32_10_t state; + if (Is_dropout) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; + curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); + } + + const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; + + Wvec gamma0[LDGS]; + Wvec beta0[LDGS]; + Wvec gamma1[LDGS]; + Wvec beta1[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma0[it].load_from(params.gamma, idx); + if (params.beta 
!= nullptr) { + beta0[it].load_from(params.beta, idx); + } else { + beta0[it].zero_(); + } + if (!Tied_norm) { + gamma1[it].load_from(params.gamma1, idx); + if (params.beta1 != nullptr) { + beta1[it].load_from(params.beta1, idx); + } else { + beta1[it].zero_(); + } + } + idx += VEC_COLS_PER_LDG; + } + } + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + compute_t xf[LDGS * NUM_ELTS]; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec x0; + Ivec x1; + Rvec residual; + Rvec x; + Mvec dmask0; + Mvec dmask1; + x0.load_from(params.x0, idx); + if (has_x1) { x1.load_from(params.x1, idx); } + if (has_residual) { residual.load_from(params.residual, idx); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use + // the more efficient curand_uniform4. + compute_t x_ij; + mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; + if (Is_dropout) { dmask0.data.elt[jt] = keep0; } + compute_t x0_ij = compute_t(x0.data.elt[jt]); + x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (has_x1) { + mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; + if (Is_dropout) { dmask1.data.elt[jt] = keep1; } + compute_t x1_ij = compute_t(x1.data.elt[jt]); + x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f; + x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij; + } else { + x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; + } + if (save_x) { x.data.elt[jt] = x_ij; } + xf[it * NUM_ELTS + jt] = x_ij; + } + if (save_x) { x.store_to(params.x, idx); } + if (Is_dropout) { + dmask0.store_to(params.dmask, idx); + if (has_x1) { dmask1.store_to(params.dmask1, idx); } + } + idx += VEC_COLS_PER_LDG; + } + } + + static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); + const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; + const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; + const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; + auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { + // Need to convert to int, otherwise the subtraction will wrap around. + const index_t valid_partial_vecs_in_warp = + std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), + int(THREADS_PER_WARP)); + return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; + }; + stats_t s = stats.template compute( + xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS + ); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ovec z0; + Ovec z1; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? 
mu : 0.f))); + compute_t g0_ij = gamma0[it].data.elt[jt]; + compute_t b0_ij = beta0[it].data.elt[jt]; + z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij); + if (!Tied_norm) { + compute_t g1_ij = gamma1[it].data.elt[jt]; + compute_t b1_ij = beta1[it].data.elt[jt]; + z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij); + } + } + z0.store_to(params.z, idx); + if (!Tied_norm) { z1.store_to(params.z1, idx); } + idx += VEC_COLS_PER_LDG; + } + } + + } +} + +} // namespace layer_norm + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG +> +void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ + + using Kernel_traits = Kernel_traits; + bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; + bool tied_norm = launch_params.params.gamma1 == nullptr; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(tied_norm, TiedNormConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_parallel_residual_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); + }); + }); +} diff --git a/ln_utils.cuh b/ln_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..178d6fda895b478ac76e2a77a2b1b35115fcc279 --- /dev/null +++ b/ln_utils.cuh @@ -0,0 +1,783 @@ +#pragma once + +#include + +#include +#include + +#include "ln.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void check_cuda_(cudaError_t status, const char *file, int line) { + if( status != cudaSuccess ) { + fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); + 
exit(status); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(ans) \ + { check_cuda_((ans), __FILE__, __LINE__); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_( \ + launch_params, configure_params); \ + } \ + static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_BWD_LAUNCHER( \ + HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_(launch_params, configure_params); \ + } \ + static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_parallel_residual_( \ + launch_params, configure_params); \ + } \ + static FwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_PARALLEL_BWD_LAUNCHER( \ + HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_parallel_residual_(launch_params, configure_params); \ + } \ + static BwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 & a, const float2 & b){ + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sum { + inline __device__ Sum(){} + inline __device__ T 
operator()(const T &a, const T &b){ + return a + b; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ + return __shfl_xor_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ + return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +} + +template +inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ + return __shfl_down_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ + return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BytesToType {}; + +template<> +struct BytesToType<64> { + using Type = uint16; + static_assert(sizeof(Type) == 64); +}; + +template<> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template<> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template<> +struct TypeToVec2 { + using Type = float2; +}; + +template<> +struct TypeToVec2 { + using Type = half2; +}; + +template<> +struct TypeToVec2 { + using Type = nv_bfloat162; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T &vec); +}; + +template<> +template +inline __device__ R Get<0>::of(const T &vec) { + return vec.x; +} + +template<> +template +inline __device__ R Get<1>::of(const T &vec) { + return vec.y; +} + +template<> +template +inline __device__ R Get<2>::of(const T &vec) { + return vec.z; +} + +template<> +template +inline __device__ R Get<3>::of(const T &vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ Dst convert(const Src &from) { + return Dst(from); + } +}; + +template<> +struct Converter{ + static inline __device__ half2 convert(const float2 &x) { + return __float22half2_rn(x); + } +}; + +template<> +struct Converter{ + static inline __device__ nv_bfloat162 convert(const float2 &x) { +#if __CUDA_ARCH__ >= 800 + return __float22bfloat162_rn(x); +#else + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = 
__float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros{ + static inline __device__ T get() { + return T(0.f); + } +}; + +template<> +struct Zeros{ + static inline __device__ float2 get() { + return make_float2(0.f, 0.f); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vec { + + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + + using Vec_type = typename BytesToType::Type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec &other) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + template + inline __device__ void assign(const Op &op) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = op(it); + } + } + + inline __device__ void zero_() { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = Elt_type(0.f); + } + } + + inline __device__ void load_from(const void *base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + inline __device__ void store_to(void *base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct InterCTASync { + + template + inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) + : phase_counter_(0) + , b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for( int found = -1; found != expected; ) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync(){ + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. 
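+        // Alternating between two barriers, and between incrementing and
+        // decrementing them, prevents a fast CTA that re-enters sync() from
+        // racing CTAs still spinning on the previous phase of the same counter.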
+            phase_counter_ = (phase_counter_ + 1) & 0x3;
+
+            if( threadIdx.x == 0 ) {
+                spin_wait_(barrier, step, expected);
+            }
+            // CTA waits for thread 0
+            __syncthreads();
+        }
+
+        int phase_counter_;
+        int * b0_;
+        int * b1_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
+
+    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
+    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
+    using Type = typename Base::Type;
+
+    enum { SMEM_BYTES = Base::SMEM_BYTES };
+
+    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
+    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
+
+    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
+    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
+
+    template<typename Params>
+    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
+        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
+        , inter_cta_(params, bidm, bidn)
+        , bidn_(bidn) // CTA id within the group.
+        , w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
+        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
+    {
+    }
+
+    template<typename Op>
+    inline __device__ T allreduce(T data, Op &op) {
+        data = Base::reduce(data, op);
+        // We switch workspace every iteration.
+        T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
+
+        // Warp leaders 0 hold the CTA-local results.
+        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
+            workspace[bidn_] = data;
+        }
+        inter_cta_.sync();
+        static_assert(CTAS_PER_ROW <= 32);
+        T total = Zeros<T>::get();
+        if(this->lane_ < CTAS_PER_ROW){
+            total = workspace[this->lane_];
+        }
+        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
+
+        return total;
+    }
+
+    InterCTASync inter_cta_;
+
+    T *w0_;
+    T *w1_;
+    int bidn_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t WARPS_M>
+struct Reducer<T, 1, WARPS_M, 1> {
+
+    using Type = T;
+    enum { SMEM_BYTES = 0 };
+    enum { WORKSPACE_BYTES_PER_GROUP = 0 };
+
+    enum { THREADS_PER_WARP = 32 };
+
+    template<typename Params>
+    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
+        : warp_n_(warp_n)
+        , lane_(lane)
+    {
+    }
+
+    template<typename Op>
+    static inline __device__ T allreduce_(T data, Op &op) {
+        #pragma unroll
+        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
+            data = op(data, warp_shuffle_xor(data, it));
+        }
+        return data;
+    }
+
+    template<typename Op>
+    inline __device__ T allreduce(T data, Op &op) {
+        return allreduce_(data, op);
+    }
+
+    template<typename Op>
+    inline __device__ T reduce(T data, Op &op){
+        // only lane 0 holds the result!
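+        // (reduce() uses shfl_down rather than shfl_xor, so unlike allreduce()
+        // the result is not broadcast; callers that need it on every lane, e.g.
+        // the WARPS_N > 1 reducer, redistribute it through shared memory.)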
+ #pragma unroll + for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { + data = op(data, warp_shuffle_down(data, it)); + } + return data; + } + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op & op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + return out; + } + + template + inline __device__ T reduce(T data, Op &op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + } + return out; + } + + T * smem0_; + T * smem1_; + bool use0_; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ + //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + + #pragma unroll + for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { + // Exchange + int_t n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. + const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
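+    // The per-CTA (mean, m2, count) partials are merged by warp_chan_upd_dynamic
+    // above, i.e. Chan et al.'s pairwise update:
+    //   n_ab  = n_a + n_b
+    //   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
+    //   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab
+    // which combines partial means/variances with different counts in one pass.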
+ + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : inter_cta_(params, bidm, bidn) + , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + , warp_n_(warp_n) + , lane_(lane) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if( warp_n_ == 0 && lane_ == 0 ) { + workspace[bidn_] = block_stats; + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. + if( lane_ < CTAS_PER_ROW ) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return { m, m2 }; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *w0_; + stats_t *w1_; + int bidn_; + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + stats_t * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + const auto warp_n = warp_stats_.reducer_.warp_n_; + const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); + stats_t warp_stats = warp_stats_.template compute( + elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts + ); + + //Each warp warp leader stores its stats + const auto lane = warp_stats_.reducer_.lane_; + if( lane == 0 ) { + smem[warp_n] = warp_stats; + } + __syncthreads(); + + int n = 0;; + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if(lane < WARPS_N){ + stats_t result = smem[lane]; + n = Is_even_cols ? 
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return { m, m2 }; + } + WarpStats warp_stats_; + stats_t * smem0_; + stats_t * smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + + auto sum = Sum(); + + T m = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + m += elts[it]; + } + } + m = reducer_.allreduce(m, sum) * row_norm_factor; + + T m2 = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + T diff = (elts[it] - m); + m2 += diff * diff; + } + } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea3784e0a14edec262ee23271e01e7e5b0c6c71 --- /dev/null +++ b/setup.py @@ -0,0 +1,203 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import os +from packaging.version import parse, Version + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from setuptools import setup, find_packages +import subprocess + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. 
Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + nvcc_threads = os.getenv("NVCC_THREADS") or "4" + return nvcc_extra_args + ["--threads", nvcc_threads] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("--fast_layer_norm") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +ext_modules.append( + CUDAExtension( + name="dropout_layer_norm", + sources=[ + "ln_api.cpp", + "ln_fwd_256.cu", + "ln_bwd_256.cu", + # "ln_fwd_512.cu", + # "ln_bwd_512.cu", + # "ln_fwd_768.cu", + # "ln_bwd_768.cu", + # "ln_fwd_1024.cu", + # "ln_bwd_1024.cu", + # "ln_fwd_1280.cu", + # "ln_bwd_1280.cu", + # "ln_fwd_1536.cu", + # "ln_bwd_1536.cu", + # "ln_fwd_2048.cu", + # "ln_bwd_2048.cu", + # "ln_fwd_2560.cu", + # "ln_bwd_2560.cu", + # "ln_fwd_3072.cu", + # "ln_bwd_3072.cu", + # 
"ln_fwd_4096.cu", + # "ln_bwd_4096.cu", + # "ln_fwd_5120.cu", + # "ln_bwd_5120.cu", + # "ln_fwd_6144.cu", + # "ln_bwd_6144.cu", + # "ln_fwd_7168.cu", + # "ln_bwd_7168.cu", + # "ln_fwd_8192.cu", + # "ln_bwd_8192.cu", + # "ln_parallel_fwd_256.cu", + # "ln_parallel_bwd_256.cu", + # "ln_parallel_fwd_512.cu", + # "ln_parallel_bwd_512.cu", + # "ln_parallel_fwd_768.cu", + # "ln_parallel_bwd_768.cu", + # "ln_parallel_fwd_1024.cu", + # "ln_parallel_bwd_1024.cu", + # "ln_parallel_fwd_1280.cu", + # "ln_parallel_bwd_1280.cu", + # "ln_parallel_fwd_1536.cu", + # "ln_parallel_bwd_1536.cu", + # "ln_parallel_fwd_2048.cu", + # "ln_parallel_bwd_2048.cu", + # "ln_parallel_fwd_2560.cu", + # "ln_parallel_bwd_2560.cu", + # "ln_parallel_fwd_3072.cu", + # "ln_parallel_bwd_3072.cu", + # "ln_parallel_fwd_4096.cu", + # "ln_parallel_bwd_4096.cu", + # "ln_parallel_fwd_5120.cu", + # "ln_parallel_bwd_5120.cu", + # "ln_parallel_fwd_6144.cu", + # "ln_parallel_bwd_6144.cu", + # "ln_parallel_fwd_7168.cu", + # "ln_parallel_bwd_7168.cu", + # "ln_parallel_fwd_8192.cu", + # "ln_parallel_bwd_8192.cu", + ], + extra_compile_args={ + "cxx": ["-O3"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[this_dir], + ) +) + +setup( + name="dropout_layer_norm", + version="0.1", + description="Fused dropout + add + layer norm", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, +) diff --git a/static_switch.h b/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..7920ac045d0a2a1f4c4159ee3eebe51fe1e2c203 --- /dev/null +++ b/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }()