diff --git a/.gitignore b/.gitignore index 7a60b85e148f80966a550e5ab6a762a907c69ca6..a2c04a9d11fc7ef97f2a8668004d1d2a3b321e39 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ __pycache__/ +**/__pycache__/ *.pyc diff --git a/CMakeLists.txt b/CMakeLists.txt deleted file mode 100644 index 8f20f8e15f719cb4e62f86f56082612467f5e02d..0000000000000000000000000000000000000000 --- a/CMakeLists.txt +++ /dev/null @@ -1,213 +0,0 @@ -cmake_minimum_required(VERSION 3.26) -project(layer_norm LANGUAGES CXX) - -set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel") - -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) - -include(FetchContent) -file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists -message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") - -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") - -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") - -include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) - -if(DEFINED Python_EXECUTABLE) - # Allow passing through the interpreter (e.g. from setup.py). - find_package(Python COMPONENTS Development Development.SABIModule Interpreter) - if (NOT Python_FOUND) - message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") - endif() -else() - find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter) -endif() - -append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") - -find_package(Torch REQUIRED) - -if (NOT TARGET_DEVICE STREQUAL "cuda" AND - NOT TARGET_DEVICE STREQUAL "rocm") - return() -endif() - -if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND - CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) - set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX") -else() - set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX") -endif() - -if (NOT HIP_FOUND AND CUDA_FOUND) - set(GPU_LANG "CUDA") - - - -elseif(HIP_FOUND) - set(GPU_LANG "HIP") - - # Importing torch recognizes and sets up some HIP/ROCm configuration but does - # not let cmake recognize .hip files. In order to get cmake to understand the - # .hip extension automatically, HIP must be enabled explicitly. - enable_language(HIP) -else() - message(FATAL_ERROR "Can't find CUDA or HIP installation.") -endif() - -if(GPU_LANG STREQUAL "CUDA") - clear_cuda_arches(CUDA_ARCH_FLAGS) - extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") - message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") - # Filter the target architectures by the supported supported archs - # since for some files we will build for all CUDA_ARCHS. - cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") - message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") - - if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA") - list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}") - endif() - - add_compile_definitions(CUDA_KERNEL) -elseif(GPU_LANG STREQUAL "HIP") - set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}") - # TODO: remove this once we can set specific archs per source file set. 
- override_gpu_arches(GPU_ARCHES - ${GPU_LANG} - "${${GPU_LANG}_SUPPORTED_ARCHS}") - - add_compile_definitions(ROCM_KERNEL) -else() - override_gpu_arches(GPU_ARCHES - ${GPU_LANG} - "${${GPU_LANG}_SUPPORTED_ARCHS}") -endif() - -get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG}) -list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS}) - -set(TORCH_layer_norm_SRC - torch-ext/torch_binding.cpp torch-ext/torch_binding.h -) - - -list(APPEND SRC "${TORCH_layer_norm_SRC}") - - -set(layer_norm_SRC - "layer_norm/ln.h" -"layer_norm/ln_api.cpp" -"layer_norm/ln_bwd_1024.cu" -"layer_norm/ln_bwd_1280.cu" -"layer_norm/ln_bwd_1536.cu" -"layer_norm/ln_bwd_2048.cu" -"layer_norm/ln_bwd_256.cu" -"layer_norm/ln_bwd_2560.cu" -"layer_norm/ln_bwd_3072.cu" -"layer_norm/ln_bwd_4096.cu" -"layer_norm/ln_bwd_512.cu" -"layer_norm/ln_bwd_5120.cu" -"layer_norm/ln_bwd_6144.cu" -"layer_norm/ln_bwd_7168.cu" -"layer_norm/ln_bwd_768.cu" -"layer_norm/ln_bwd_8192.cu" -"layer_norm/ln_bwd_kernels.cuh" -"layer_norm/ln_fwd_1024.cu" -"layer_norm/ln_fwd_1280.cu" -"layer_norm/ln_fwd_1536.cu" -"layer_norm/ln_fwd_2048.cu" -"layer_norm/ln_fwd_256.cu" -"layer_norm/ln_fwd_2560.cu" -"layer_norm/ln_fwd_3072.cu" -"layer_norm/ln_fwd_4096.cu" -"layer_norm/ln_fwd_512.cu" -"layer_norm/ln_fwd_5120.cu" -"layer_norm/ln_fwd_6144.cu" -"layer_norm/ln_fwd_7168.cu" -"layer_norm/ln_fwd_768.cu" -"layer_norm/ln_fwd_8192.cu" -"layer_norm/ln_fwd_kernels.cuh" -"layer_norm/ln_kernel_traits.h" -"layer_norm/ln_parallel_bwd_1024.cu" -"layer_norm/ln_parallel_bwd_1280.cu" -"layer_norm/ln_parallel_bwd_1536.cu" -"layer_norm/ln_parallel_bwd_2048.cu" -"layer_norm/ln_parallel_bwd_256.cu" -"layer_norm/ln_parallel_bwd_2560.cu" -"layer_norm/ln_parallel_bwd_3072.cu" -"layer_norm/ln_parallel_bwd_4096.cu" -"layer_norm/ln_parallel_bwd_512.cu" -"layer_norm/ln_parallel_bwd_5120.cu" -"layer_norm/ln_parallel_bwd_6144.cu" -"layer_norm/ln_parallel_bwd_7168.cu" -"layer_norm/ln_parallel_bwd_768.cu" -"layer_norm/ln_parallel_bwd_8192.cu" -"layer_norm/ln_parallel_fwd_1024.cu" -"layer_norm/ln_parallel_fwd_1280.cu" -"layer_norm/ln_parallel_fwd_1536.cu" -"layer_norm/ln_parallel_fwd_2048.cu" -"layer_norm/ln_parallel_fwd_256.cu" -"layer_norm/ln_parallel_fwd_2560.cu" -"layer_norm/ln_parallel_fwd_3072.cu" -"layer_norm/ln_parallel_fwd_4096.cu" -"layer_norm/ln_parallel_fwd_512.cu" -"layer_norm/ln_parallel_fwd_5120.cu" -"layer_norm/ln_parallel_fwd_6144.cu" -"layer_norm/ln_parallel_fwd_7168.cu" -"layer_norm/ln_parallel_fwd_768.cu" -"layer_norm/ln_parallel_fwd_8192.cu" -"layer_norm/ln_parallel_residual_bwd_kernels.cuh" -"layer_norm/ln_parallel_residual_fwd_kernels.cuh" -"layer_norm/ln_utils.cuh" -"layer_norm/static_switch.h" -) - -# TODO: check if CLion support this: -# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories -set_source_files_properties( - ${layer_norm_SRC} - PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.") - -if(GPU_LANG STREQUAL "CUDA") - cuda_archs_loose_intersection(layer_norm_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}") - message(STATUS "Capabilities for kernel layer_norm: ${layer_norm_ARCHS}") - set_gencode_flags_for_srcs(SRCS "${layer_norm_SRC}" CUDA_ARCHS "${layer_norm_ARCHS}") - - - foreach(_KERNEL_SRC ${layer_norm_SRC}) - if(_KERNEL_SRC MATCHES ".*\\.cu$") - set_property( - SOURCE ${_KERNEL_SRC} - APPEND PROPERTY - COMPILE_OPTIONS 
"$<$:-O3;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;-U__CUDA_NO_BFLOAT16_OPERATORS__;-U__CUDA_NO_BFLOAT16_CONVERSIONS__;-U__CUDA_NO_BFLOAT162_OPERATORS__;-U__CUDA_NO_BFLOAT162_CONVERSIONS__;--expt-relaxed-constexpr;--expt-extended-lambda;--use_fast_math>" - ) - endif() - endforeach() - - foreach(_KERNEL_SRC ${layer_norm_SRC}) - set_property( - SOURCE ${_KERNEL_SRC} - APPEND PROPERTY - COMPILE_OPTIONS "$<$:-DFLASHATTENTION_DISABLE_PYBIND>" - ) - endforeach() - - list(APPEND SRC "${layer_norm_SRC}") -endif() - - -define_gpu_extension_target( - _layer_norm_711aa42_dirty - DESTINATION _layer_norm_711aa42_dirty - LANGUAGE ${GPU_LANG} - SOURCES ${SRC} - COMPILE_FLAGS ${GPU_FLAGS} - ARCHITECTURES ${GPU_ARCHES} - #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} - USE_SABI 3 - WITH_SOABI) - -target_link_options(_layer_norm_711aa42_dirty PRIVATE -static-libstdc++) - diff --git a/README.md b/README.md index 780410f8715b7915842faae9b6777975274ad678..bce7386e835bccaaf051975bc84096e7bbfd9395 100644 --- a/README.md +++ b/README.md @@ -2,23 +2,4 @@ tags: - kernel --- -This CUDA extension implements fused dropout + residual + LayerNorm, building on -Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). -Major changes: -- Add dropout and residual. -- Make it work for both pre-norm and post-norm architecture. -- Support more hidden dimensions (all dimensions divisible by 8, up to 8192). -- Implement RMSNorm as an option. -- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM). - -If you want to use it for dimensions larger than 8k, please file an issue. - -This extension has only been tested on A100s. - -```sh -cd csrc/layer_norm && pip install . -``` - -As of 2024-01-05, this extension is no longer used in the FlashAttention repo. -We've instead switched to a Triton-based -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py). \ No newline at end of file +This CUDA extension implements fused dropout + residual + LayerNorm from the [flash-attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm) repo. \ No newline at end of file diff --git a/api.py b/api.py deleted file mode 100644 index 4b6cd798fd02844ef9cd3897f8ab95e490e638bf..0000000000000000000000000000000000000000 --- a/api.py +++ /dev/null @@ -1,800 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import dropout_layer_norm -import torch -from torch.nn import init - - -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 - return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). 
- """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) - else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) - - -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x0, residual=None): - return dropout_add_layer_norm( - x0, - residual, - self.weight, - self.bias, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b65a1144bf7a355e11357dbe6c99ab21295576f7 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:790cd814bbfcaf7ff83b5c68bcb91091a67f34e92b9a2494e2856462e71a3141 +size 716945944 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index 0953f1318fb68e86696a97b02b322fab559cacc8..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fe0515daaf1bbfd1246d18bd5c1a5cd6f366059090a8b6e402955d06caaa6392 -size 716945976 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3df71f8d95ec614c3d3c12cc5d2216ae980bda2a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b17984ef79fc9d6427c8efe0a8cc8f1f6e2777f9a8641b86556b7bb2359626ab +size 712024816 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index 414995bf072eaad4caba0c076e28e88587c336ac..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:04095de2e4bf9cd03f9ec481084d0c9e9e0baa0bab17a0ec9715f22f69bdfd33 -size 712024848 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e445ce11dae02e2c0efc6a80fd674cc94efe7d7f --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7629b13b777a390df75374fc60d85311679a56a5bbd9969e138822e5c0fe2b1e +size 1231333360 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index 9e06402e0e2ac7406d5d9b30600c4bd38cc8c2b4..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ae0d54be8ee4e3ae33f47f0b27243c9cbd5668ff7756b1dfb5dcd9e2430f5a35 -size 1231333392 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..88458ff7e97a031e3862ca60aef9f034d490dce6 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6ffc9d5651e8de6440f2d4f58018a5ded07634582ae03eec5b9edf428f613a6 +size 712024904 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index f1ce54f7d12c0dc8073ba02a4b8f8c199c67c3b7..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:12b6de6cef24c5ee7a390d91ee2ea7069533e66440cf78ae5df7ae3beff5c1ca -size 712024936 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..58492a533279aab2d1237e1a3791b470e30a1335 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df39795e047e962019cbecbb11f93d8ee1fcfb49ed8326f2edc267bc0d90da08 +size 1231337936 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index 21a4b9c34deea4cbdc2e92358260b35b55f903de..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d51ec6b6da7095cf5fc18493eb4b0b1c20485f01dff4b38370979ea3d0a9dd60 -size 1231337968 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f351a91ee4f7a30ff052cb2432d440b9dc873fa5 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfef6947945f8f126a284c6a8ab861e180a5e628992eeb0b4b7c7914c50a59c2 +size 1283037344 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so deleted file mode 100755 index 16f9e3c511feae6058c68d6cc24e35252ab0954a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9080934ece3b5e09db6178b1baa15b8baf9f6873e234a951a2122071e1190fba -size 1283037376 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py index 5c75ec08feea6f660d15b32a492e40a6ada28802..228f4a40387a221a0c80bb02e7493b7192a8641c 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _layer_norm_f622ea1_dirty -ops = torch.ops._layer_norm_f622ea1_dirty +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_layer_norm_f622ea1_dirty::{op_name}" \ No newline at end of file + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/torch-ext/layer_norm/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/__init__.py similarity index 100% rename from torch-ext/layer_norm/__init__.py rename to build/torch29-cxx11-cu126-x86_64-linux/layer_norm/__init__.py diff --git a/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..962ce37910eb8988192e95359f8d071d82a8eed0 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bdb57c0889ade2fc574156873c1d4b543796f2e8ad6a894be82ee2785459c9b +size 712029160 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..228f4a40387a221a0c80bb02e7493b7192a8641c --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/torch-ext/layer_norm/layers.py b/build/torch29-cxx11-cu126-x86_64-linux/layer_norm/layers.py similarity index 100% rename from torch-ext/layer_norm/layers.py rename to build/torch29-cxx11-cu126-x86_64-linux/layer_norm/layers.py diff --git a/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..946160f16b9dc91fefeea037fb7ac84fd6afd802 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +from ._ops import ops + +from . 
import layers + +def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm): + return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm) + +def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm): + return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm) + +def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm): + return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm) + +def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm): + return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm) + +__all__ = [ + "layers", + "dropout_add_ln_fwd", + "dropout_add_ln_bwd", + "dropout_add_ln_parallel_residual_fwd", + "dropout_add_ln_parallel_residual_bwd", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5a1b11bc582182bc52a10a63f1c9d3a58952ba5e --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03e6e7ecbf276b306d89607100f78f2ce8b3385a77594676dbf0daabdce26fc7 +size 1231338080 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..228f4a40387a221a0c80bb02e7493b7192a8641c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed883f42ead452f8b60f498ec11302c53d3cf74 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +from ._ops import ops + + +class LayerNorm(nn.Module): + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = ops.dropout_add_ln_fwd( + hidden_states.view(-1, hidden_states.shape[-1]), + gamma = self.weight, + beta = None, + rowscale = None, + colscale = None, + x0_subset = None, + z_subset = None, + dropout_p = 0, + epsilon = self.variance_epsilon, + rowscale_const = 1.0, + z_numrows = hidden_states.shape[1], + gen = None, + residual_in_fp32 = False, + is_rms_norm = False, + ) + return output[0].view(hidden_states.shape) + +class LlamaRMSNorm(nn.Module): + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = ops.dropout_add_ln_fwd( + hidden_states.view(-1, hidden_states.shape[-1]), + gamma = self.weight, + beta = None, + rowscale = None, + colscale = None, + x0_subset = None, + z_subset = None, + dropout_p = 0, + epsilon = self.variance_epsilon, + rowscale_const = 1.0, + z_numrows = hidden_states.shape[1], + gen = None, + residual_in_fp32 = False, + is_rms_norm = True, + ) + return output[0].view(hidden_states.shape) \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..946160f16b9dc91fefeea037fb7ac84fd6afd802 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +from ._ops import ops + +from . 
import layers + +def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm): + return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm) + +def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm): + return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm) + +def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm): + return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm) + +def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm): + return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm) + +__all__ = [ + "layers", + "dropout_add_ln_fwd", + "dropout_add_ln_bwd", + "dropout_add_ln_parallel_residual_fwd", + "dropout_add_ln_parallel_residual_bwd", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7d4138d38df27d77d33d643bb64e07756f7757e4 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:322e2d8fc69447be95ef7b6e85267e8769f1284419baa606732a77b1980a834d +size 1238333264 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..228f4a40387a221a0c80bb02e7493b7192a8641c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _layer_norm_f3fd6bf +ops = torch.ops._layer_norm_f3fd6bf + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_layer_norm_f3fd6bf::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed883f42ead452f8b60f498ec11302c53d3cf74 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +from ._ops import ops + + +class LayerNorm(nn.Module): + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = ops.dropout_add_ln_fwd( + hidden_states.view(-1, hidden_states.shape[-1]), + gamma = self.weight, + beta = None, + rowscale = None, + colscale = None, + x0_subset = None, + z_subset = None, + dropout_p = 0, + epsilon = self.variance_epsilon, + rowscale_const = 1.0, + z_numrows = hidden_states.shape[1], + gen = None, + residual_in_fp32 = False, + is_rms_norm = False, + ) + return output[0].view(hidden_states.shape) + +class LlamaRMSNorm(nn.Module): + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = ops.dropout_add_ln_fwd( + hidden_states.view(-1, hidden_states.shape[-1]), + gamma = self.weight, + beta = None, + rowscale = None, + colscale = None, + x0_subset = None, + z_subset = None, + dropout_p = 0, + epsilon = self.variance_epsilon, + rowscale_const = 1.0, + z_numrows = hidden_states.shape[1], + gen = None, + residual_in_fp32 = False, + is_rms_norm = True, + ) + return output[0].view(hidden_states.shape) \ No newline at end of file diff --git a/cmake/hipify.py b/cmake/hipify.py deleted file mode 100644 index a1539c02a297d2f9abe66700c18c168079c6987a..0000000000000000000000000000000000000000 --- a/cmake/hipify.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 - -# From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py - -# -# A command line tool for running pytorch's hipify preprocessor on CUDA -# source files. -# -# See https://github.com/ROCm/hipify_torch -# and /utils/hipify/hipify_python.py -# - -import argparse -import os -import shutil - -from torch.utils.hipify.hipify_python import hipify - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - # Project directory where all the source + include files live. - parser.add_argument( - "-p", - "--project_dir", - help="The project directory.", - ) - - # Directory where hipified files are written. - parser.add_argument( - "-o", - "--output_dir", - help="The output directory.", - ) - - # Source files to convert. - parser.add_argument("sources", - help="Source files to hipify.", - nargs="*", - default=[]) - - args = parser.parse_args() - - # Limit include scope to project_dir only - includes = [os.path.join(args.project_dir, '*')] - - # Get absolute path for all source files. - extra_files = [os.path.abspath(s) for s in args.sources] - - # Copy sources from project directory to output directory. - # The directory might already exist to hold object files so we ignore that. 
- shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) - - hipify_result = hipify(project_directory=args.project_dir, - output_directory=args.output_dir, - header_include_dirs=[], - includes=includes, - extra_files=extra_files, - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True) - - hipified_sources = [] - for source in args.sources: - s_abs = os.path.abspath(source) - hipified_s_abs = (hipify_result[s_abs].hipified_path if - (s_abs in hipify_result - and hipify_result[s_abs].hipified_path is not None) - else s_abs) - hipified_sources.append(hipified_s_abs) - - assert (len(hipified_sources) == len(args.sources)) - - # Print hipified source files. - print("\n".join(hipified_sources)) diff --git a/cmake/utils.cmake b/cmake/utils.cmake deleted file mode 100644 index 6c87d51ed6cb9b0cc07af011190b7e2cec6d8a58..0000000000000000000000000000000000000000 --- a/cmake/utils.cmake +++ /dev/null @@ -1,545 +0,0 @@ -# Vendored from vLLM: -# -# https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake -# -# Attempt to find the python package that uses the same python executable as -# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`. -# -macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) - file(REAL_PATH ${EXECUTABLE} EXECUTABLE) - set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) - if (NOT Python_FOUND) - message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") - endif() - set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}") - set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) - if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) - message(FATAL_ERROR - "Python version (${_VER}) is not one of the supported versions: " - "${_SUPPORTED_VERSIONS_LIST}.") - endif() - message(STATUS "Found python matching: ${EXECUTABLE}.") -endmacro() - -# -# Run `EXPR` in python. The standard output of python is stored in `OUT` and -# has trailing whitespace stripped. If an error is encountered when running -# python, a fatal message `ERR_MSG` is issued. -# -function (run_python OUT EXPR ERR_MSG) - execute_process( - COMMAND - "${Python_EXECUTABLE}" "-c" "${EXPR}" - OUTPUT_VARIABLE PYTHON_OUT - RESULT_VARIABLE PYTHON_ERROR_CODE - ERROR_VARIABLE PYTHON_STDERR - OUTPUT_STRIP_TRAILING_WHITESPACE) - - if(NOT PYTHON_ERROR_CODE EQUAL 0) - message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") - endif() - set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) -endfunction() - -# Run `EXPR` in python after importing `PKG`. Use the result of this to extend -# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. -macro (append_cmake_prefix_path PKG EXPR) - run_python(_PREFIX_PATH - "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") - list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) -endmacro() - -# -# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set -# of CUDA source files. The names of the corresponding "hipified" sources are -# stored in `OUT_SRCS`. -# -function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) - # - # Split into C++ and non-C++ (i.e. CUDA) sources. - # - set(NODUP_SRCS ${ORIG_SRCS}) - list(REMOVE_DUPLICATES NODUP_SRCS) - set(SRCS ${NODUP_SRCS}) - set(CXX_SRCS ${NODUP_SRCS}) - list(FILTER SRCS INCLUDE REGEX "\.cu$") - list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$") - - # - # Generate ROCm/HIP source file names from CUDA file names. 
- # Since HIP files are generated code, they will appear in the build area - # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir. - # - set(HIP_SRCS) - foreach (SRC ${SRCS}) - get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES) - string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC}) - string(REGEX REPLACE "cuda" "hip" SRC ${SRC}) - - if(include_dirs) - # Copy over include directories from the original CUDA file. - set_source_files_properties( - ${SRC} - PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}") - endif() - - list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}") - endforeach() - - add_custom_target( - hipify${NAME} - COMMAND "${Python_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS} - DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS} - BYPRODUCTS ${HIP_SRCS} - COMMENT "Running hipify on ${NAME} extension source files.") - - # Swap out original extension sources with hipified sources. - list(APPEND HIP_SRCS ${CXX_SRCS}) - set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE) -endfunction() - -# -# Get additional GPU compiler flags from torch. -# -function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) - if (${GPU_LANG} STREQUAL "CUDA") - # - # Get common NVCC flags from torch. - # - run_python(GPU_FLAGS - "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))" - "Failed to determine torch nvcc compiler flags") - - if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8") - list(REMOVE_ITEM GPU_FLAGS - "-D__CUDA_NO_HALF_OPERATORS__" - "-D__CUDA_NO_HALF_CONVERSIONS__" - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" - "-D__CUDA_NO_HALF2_OPERATORS__") - endif() - - elseif(${GPU_LANG} STREQUAL "HIP") - # - # Get common HIP/HIPCC flags from torch. - # - run_python(GPU_FLAGS - "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" - "Failed to determine torch nvcc compiler flags") - - list(APPEND GPU_FLAGS - "-DUSE_ROCM" - "-DENABLE_FP8" - "-U__HIP_NO_HALF_CONVERSIONS__" - "-U__HIP_NO_HALF_OPERATORS__" - "-fno-gpu-rdc") - - endif() - set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) -endfunction() - -# Macro for converting a `gencode` version number to a cmake version number. -macro(string_to_ver OUT_VER IN_STR) - string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) -endmacro() - -# -# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in -# `CUDA_ARCH_FLAGS`. -# -# Example: -# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" -# clear_cuda_arches(CUDA_ARCH_FLAGS) -# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" -# CMAKE_CUDA_FLAGS="-Wall" -# -macro(clear_cuda_arches CUDA_ARCH_FLAGS) - # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` - string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified - # and passed back via the `CUDA_ARCHITECTURES` property. - string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS - ${CMAKE_CUDA_FLAGS}) -endmacro() - -# -# Extract unique CUDA architectures from a list of compute capabilities codes in -# the form `[]`, convert them to the form sort -# `.`, dedupes them and then sorts them in ascending order and -# stores them in `OUT_ARCHES`. 
-# -# Example: -# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a" -# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS) -# OUT_ARCHES="7.5;...;9.0" -function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS) - set(_CUDA_ARCHES) - foreach(_ARCH ${CUDA_ARCH_FLAGS}) - string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) - if (_COMPUTE) - set(_COMPUTE ${CMAKE_MATCH_1}) - endif() - - string_to_ver(_COMPUTE_VER ${_COMPUTE}) - list(APPEND _CUDA_ARCHES ${_COMPUTE_VER}) - endforeach() - - list(REMOVE_DUPLICATES _CUDA_ARCHES) - list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING) - set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE) -endfunction() - -# -# For a specific file set the `-gencode` flag in compile options conditionally -# for the CUDA language. -# -# Example: -# set_gencode_flag_for_srcs( -# SRCS "foo.cu" -# ARCH "compute_75" -# CODE "sm_75") -# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for -# `foo.cu` (only for the CUDA language). -# -macro(set_gencode_flag_for_srcs) - set(options) - set(oneValueArgs ARCH CODE) - set(multiValueArgs SRCS) - cmake_parse_arguments(arg "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE}) - set_property( - SOURCE ${arg_SRCS} - APPEND PROPERTY - COMPILE_OPTIONS "$<$:${_FLAG}>" - ) - - message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}") -endmacro(set_gencode_flag_for_srcs) - -# -# For a list of source files set the `-gencode` flags in the files specific -# compile options (specifically for the CUDA language). -# -# arguments are: -# SRCS: list of source files -# CUDA_ARCHS: list of CUDA architectures in the form `.[letter]` -# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built -# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS -# that is larger than BUILD_PTX_FOR_ARCH. -# -macro(set_gencode_flags_for_srcs) - set(options) - set(oneValueArgs BUILD_PTX_FOR_ARCH) - set(multiValueArgs SRCS CUDA_ARCHS) - cmake_parse_arguments(arg "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - foreach(_ARCH ${arg_CUDA_ARCHS}) - # handle +PTX suffix: generate both sm and ptx codes if requested - string(FIND "${_ARCH}" "+PTX" _HAS_PTX) - if(NOT _HAS_PTX EQUAL -1) - string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") - string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_STRIPPED_ARCH}" - CODE "sm_${_STRIPPED_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_STRIPPED_ARCH}" - CODE "compute_${_STRIPPED_ARCH}") - else() - string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_STRIPPED_ARCH}" - CODE "sm_${_STRIPPED_ARCH}") - endif() - endforeach() - - if (${arg_BUILD_PTX_FOR_ARCH}) - list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) - list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH) - if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH}) - string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_PTX_ARCH}" - CODE "compute_${_PTX_ARCH}") - endif() - endif() -endmacro() - -# -# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form -# `.[letter]` compute the "loose intersection" with the -# `TGT_CUDA_ARCHS` list of gencodes. 
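A rough Python rendering of the gencode handling above: parsing `-gencode` flags into deduplicated, ascending `major.minor` versions, and expanding a `+PTX` entry into both a SASS code and an embedded, forward-compatible PTX code. It is illustrative only and ignores the BUILD_PTX_FOR_ARCH path.

import re

def extract_unique_cuda_archs_ascending(gencode_flags):
    # "-gencode arch=compute_90a,code=sm_90a" -> "9.0a"
    archs = set()
    for flag in gencode_flags:
        m = re.search(r"arch=compute_(\d+)(a?)", flag)
        if m:
            digits, suffix = m.groups()
            archs.add(f"{digits[:-1]}.{digits[-1]}{suffix}")
    return sorted(archs, key=lambda v: (float(v.rstrip("a")), v))

def gencode_flags_for(arch):
    # "8.6" -> SASS only; "9.0+PTX" -> SASS plus embedded PTX.
    base = arch.replace("+PTX", "").replace(".", "")
    flags = [f"-gencode arch=compute_{base},code=sm_{base}"]
    if arch.endswith("+PTX"):
        flags.append(f"-gencode arch=compute_{base},code=compute_{base}")
    return flags

print(extract_unique_cuda_archs_ascending(
    ["-gencode arch=compute_75,code=sm_75", "-gencode arch=compute_70,code=sm_70"]))  # ['7.0', '7.5']
print(gencode_flags_for("9.0+PTX"))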
We also support the `+PTX` suffix in -# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there -# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the -# architecture in `SRC_CUDA_ARCHS`. -# The loose intersection is defined as: -# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } -# where `<=` is the version comparison operator. -# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version -# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. -# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is -# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add -# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). -# The result is stored in `OUT_CUDA_ARCHS`. -# -# Example: -# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a" -# TGT_CUDA_ARCHS="8.0;8.9;9.0" -# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) -# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" -# -# Example With PTX: -# SRC_CUDA_ARCHS="8.0+PTX" -# TGT_CUDA_ARCHS="9.0" -# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) -# OUT_CUDA_ARCHS="8.0+PTX" -# -function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) - set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") - set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) - - # handle +PTX suffix: separate base arch for matching, record PTX requests - set(_PTX_ARCHS) - foreach(_arch ${_SRC_CUDA_ARCHS}) - if(_arch MATCHES "\\+PTX$") - string(REPLACE "+PTX" "" _base "${_arch}") - list(APPEND _PTX_ARCHS "${_base}") - list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") - list(APPEND _SRC_CUDA_ARCHS "${_base}") - endif() - endforeach() - list(REMOVE_DUPLICATES _PTX_ARCHS) - list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) - - # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should - # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS - set(_CUDA_ARCHS) - foreach(_arch ${_SRC_CUDA_ARCHS}) - if(_arch MATCHES "\\a$") - list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") - string(REPLACE "a" "" _base "${_arch}") - if ("${_base}" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") - list(APPEND _CUDA_ARCHS "${_arch}") - endif() - endif() - endforeach() - - list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) - - # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that - # is less or equal to ARCH (but has the same major version since SASS binary - # compatibility is only forward compatible within the same major version). 
- foreach(_ARCH ${_TGT_CUDA_ARCHS}) - set(_TMP_ARCH) - # Extract the major version of the target arch - string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") - foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) - # Extract the major version of the source arch - string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") - # Check version-less-or-equal, and allow PTX arches to match across majors - if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) - set(_TMP_ARCH "${_SRC_ARCH}") - endif() - else() - # If we hit a version greater than the target, we can break - break() - endif() - endforeach() - - # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS - if (_TMP_ARCH) - list(APPEND _CUDA_ARCHS "${_TMP_ARCH}") - endif() - endforeach() - - list(REMOVE_DUPLICATES _CUDA_ARCHS) - - # reapply +PTX suffix to architectures that requested PTX - set(_FINAL_ARCHS) - foreach(_arch ${_CUDA_ARCHS}) - if(_arch IN_LIST _PTX_ARCHS) - list(APPEND _FINAL_ARCHS "${_arch}+PTX") - else() - list(APPEND _FINAL_ARCHS "${_arch}") - endif() - endforeach() - set(_CUDA_ARCHS ${_FINAL_ARCHS}) - - set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) -endfunction() - -# -# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form -# `` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list. -# The loose intersection is defined as: -# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } -# where `<=` is the version comparison operator. -# In other words, for each version in `TGT_ROCM_ARCHS` find the highest version -# in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`. -# The result is stored in `OUT_ROCM_ARCHS`. -# -# Example: -# SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a" -# TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030" -# hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS) -# OUT_ROCM_ARCHS="gfx906;gfx908" -# -function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS) - list(REMOVE_DUPLICATES SRC_ROCM_ARCHS) - - # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit - # and x is a letter. We can sort them by string comparison which works for this format. - list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING) - - set(_ROCM_ARCHS) - - # Find the intersection of supported architectures - foreach(_SRC_ARCH ${SRC_ROCM_ARCHS}) - if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS) - list(APPEND _ROCM_ARCHS ${_SRC_ARCH}) - endif() - endforeach() - - list(REMOVE_DUPLICATES _ROCM_ARCHS) - set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE) -endfunction() - -# -# Override the GPU architectures detected by cmake/torch and filter them by -# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in -# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set -# the architectures on a per file basis. -# -# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. -# -macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) - set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN}) - message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}") - - if (${GPU_LANG} STREQUAL "HIP") - # - # `GPU_ARCHES` controls the `--offload-arch` flags. 
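A compact Python rendering of the matching rule implemented above; it mirrors the `+PTX`, `x.0a` and same-major-version constraints and is meant only as an illustration of the selection logic, not a drop-in replacement.

def cuda_archs_loose_intersection(src_archs, tgt_archs):
    def ver(a):
        return float(a.rstrip("a"))

    ptx = {a[:-4] for a in src_archs if a.endswith("+PTX")}
    src = {a[:-4] if a.endswith("+PTX") else a for a in src_archs}
    tgt = list(tgt_archs)
    out = []

    # x.0a special case: keep the 'a' variant when the plain arch is targeted.
    for a in [a for a in src if a.endswith("a")]:
        if a[:-1] in tgt:
            src.discard(a)
            tgt.remove(a[:-1])
            out.append(a)

    for t in tgt:
        best = None
        for s in sorted(src, key=ver):
            if ver(s) > ver(t):
                break
            # SASS is only forward-compatible within a major version; PTX is unrestricted.
            if s in ptx or int(ver(s)) == int(ver(t)):
                best = s
        if best is not None:
            out.append(best)

    # Drop duplicates, then reapply the +PTX marker where it was requested.
    return [a + "+PTX" if a in ptx else a for a in dict.fromkeys(out)]

print(cuda_archs_loose_intersection(["8.0+PTX"], ["9.0"]))  # ['8.0+PTX']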
- # - # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, - # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling - # "rocm_agent_enumerator" in "enable_language(HIP)" - # (in file Modules/CMakeDetermineHIPCompiler.cmake) - # - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) - else() - set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) - endif() - # - # Find the intersection of the supported + detected architectures to - # set the module architecture flags. - # - set(${GPU_ARCHES}) - foreach (_ARCH ${HIP_ARCHITECTURES}) - if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) - list(APPEND ${GPU_ARCHES} ${_ARCH}) - endif() - endforeach() - - if(NOT ${GPU_ARCHES}) - message(FATAL_ERROR - "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" - " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") - endif() - endif() -endmacro() - -# -# Define a target named `GPU_MOD_NAME` for a single extension. The -# arguments are: -# -# DESTINATION - Module destination directory. -# LANGUAGE - The GPU language for this module, e.g CUDA, HIP, -# etc. -# SOURCES - List of source files relative to CMakeLists.txt -# directory. -# -# Optional arguments: -# -# ARCHITECTURES - A list of target GPU architectures in cmake -# format. -# Refer `CMAKE_CUDA_ARCHITECTURES` documentation -# and `CMAKE_HIP_ARCHITECTURES` for more info. -# ARCHITECTURES will use cmake's defaults if -# not provided. -# COMPILE_FLAGS - Extra compiler flags passed to NVCC/hip. -# INCLUDE_DIRECTORIES - Extra include directories. -# LIBRARIES - Extra link libraries. -# WITH_SOABI - Generate library with python SOABI suffix name. -# USE_SABI - Use python stable api -# -# Note: optimization level/debug info is set via cmake build type. -# -function (define_gpu_extension_target GPU_MOD_NAME) - cmake_parse_arguments(PARSE_ARGV 1 - GPU - "WITH_SOABI" - "DESTINATION;LANGUAGE;USE_SABI" - "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") - - # Add hipify preprocessing step when building with HIP/ROCm. - if (GPU_LANGUAGE STREQUAL "HIP") - hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}") - endif() - - if (GPU_WITH_SOABI) - set(GPU_WITH_SOABI WITH_SOABI) - else() - set(GPU_WITH_SOABI) - endif() - - if (GPU_USE_SABI) - Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") - else() - Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") - endif() - - if (GPU_LANGUAGE STREQUAL "HIP") - # Make this target dependent on the hipify preprocessor step. - add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) - endif() - - if (GPU_ARCHITECTURES) - set_target_properties(${GPU_MOD_NAME} PROPERTIES - ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") - endif() - - set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17) - - target_compile_options(${GPU_MOD_NAME} PRIVATE - $<$:${GPU_COMPILE_FLAGS}>) - - target_compile_definitions(${GPU_MOD_NAME} PRIVATE - "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}") - - target_include_directories(${GPU_MOD_NAME} PRIVATE csrc - ${GPU_INCLUDE_DIRECTORIES}) - - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) - - # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of - # dependencies that are not necessary and may not be installed. 
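The HIP branch above boils down to: take PYTORCH_ROCM_ARCH if it is set, otherwise the architectures CMake detected, then keep only those the kernel supports. A minimal Python sketch of that selection, with illustrative argument names:

import os

def select_rocm_arches(supported, detected):
    requested = os.environ.get("PYTORCH_ROCM_ARCH")
    # The environment variable wins over the detected list (e.g. from rocm_agent_enumerator).
    arches = requested.split(";") if requested else list(detected)
    chosen = [a for a in arches if a in supported]
    if not chosen:
        raise RuntimeError(f"None of the ROCm architectures {arches} is supported "
                           f"(supported: {supported}).")
    return chosen

print(select_rocm_arches(["gfx90a", "gfx942"], ["gfx90a", "gfx1030"]))  # ['gfx90a'] when the env var is unset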
- if (GPU_LANGUAGE STREQUAL "CUDA") - target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart) - else() - target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) - endif() - - install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) -endfunction() diff --git a/layer_norm/ln.h b/layer_norm/ln.h deleted file mode 100644 index 9830c092d0aca9f3466154a18d1d3c32d651716e..0000000000000000000000000000000000000000 --- a/layer_norm/ln.h +++ /dev/null @@ -1,281 +0,0 @@ -#pragma once - -#include -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams{ - - size_t elts_per_thread; - size_t workspace_bytes; - size_t barrier_size; - - cudaDeviceProp * props; - - cudaStream_t stream; - - Params params; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0) - , rows(0) - , cols(0) - , x(nullptr) - , mu(nullptr) - , rs(nullptr) - , gamma(nullptr) - , gamma1(nullptr) - , rowscale(nullptr) - , colscale(nullptr) - , dropout_keep_p(1.f) - , dropout_scale(1.f) - , is_rms_norm(false) - , workspace(nullptr) - , barrier(nullptr) - { - } - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x0; - void *x1; - void *residual; - void *x; - void *dmask; - void *dmask1; - void *mu; - void *rs; - void *gamma; - void *gamma1; - void *rowscale; - void *colscale; - void *x0_subset; - void *z_subset; - - float inverse_cols; - - float dropout_keep_p; - float dropout_scale; - float rowscale_const; - - bool is_rms_norm; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , z1(nullptr) - , beta(nullptr) - , beta1(nullptr) - , epsilon(0.f) - { - } - - // Output of LN FWD. - void *z; - void *z1; - void *beta; - void *beta1; - float epsilon; - - // Random state. - at::PhiloxCudaState philox_args; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase() - , dz(nullptr) - , dz1(nullptr) - , dx(nullptr) - , dbeta_part(nullptr) - , dgamma_part(nullptr) - , dbeta1_part(nullptr) - , dgamma1_part(nullptr) - , dcolscale_part(nullptr) - , dx0(nullptr) - , dx1(nullptr) - , dresidual(nullptr) - , dbeta(nullptr) - , dgamma(nullptr) - , dbeta1(nullptr) - , dgamma1(nullptr) - , dcolscale(nullptr) - { - } - - // Input: gradient wrt. LN FWD output. - void *dz; - void *dz1; - // Input: gradient wrt residual. - void *dx; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - void *dbeta1_part; - void *dgamma1_part; - void *dcolscale_part; - - // Output: Dgrad. - void *dx0; - void *dx1; - void *dresidual; - // Output: Wgrad. 
- void *dbeta; - void *dgamma; - void *dbeta1; - void *dgamma1; - void *dcolscale; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool)>; -using FunctionKey = uint64_t; -using FwdRegistry = std::unordered_map; -using BwdRegistry = std::unordered_map; - -extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; -extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using fp32 = float; -using fp16 = half; -using bf16 = nv_bfloat16; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId{}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 0; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 1; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Type2Key{ - constexpr static uint32_t Value = TypeId::Value << S; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct WeightType2Key : public Type2Key{}; - -template -struct InputType2Key : public Type2Key{}; - -template -struct ResidualType2Key : public Type2Key{}; - -template -struct OutputType2Key : public Type2Key{}; - -template -struct ComputeType2Key : public Type2Key{}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key{ - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size){ - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdRegistrar{ - FwdRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdRegistrar{ - BwdRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdParallelRegistrar{ - FwdParallelRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - PARALLEL_FWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdParallelRegistrar{ - BwdParallelRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - PARALLEL_BWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/layer_norm/ln_api.cpp b/layer_norm/ln_api.cpp deleted file mode 100644 index 39481bdd4be5b0457d2a6ab9341ad9b19a3f2fa7..0000000000000000000000000000000000000000 --- a/layer_norm/ln_api.cpp +++ /dev/null @@ -1,828 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" 
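The registries declared above are keyed by a 64-bit value that packs the five participating dtypes (two bits each) together with the hidden size. A Python sketch of that packing follows; the 0/1/2 id per dtype is an assumption for illustration (the real ids come from the TypeId specializations), while the shifts and the `(type_key << 32) | hidden_size` layout match the key construction used by the API below.

TYPE_ID = {"fp16": 0, "bf16": 1, "fp32": 2}  # assumed ids, for illustration only

def launcher_key(wtype, itype, rtype, otype, ctype, hidden_size):
    type_key = (TYPE_ID[wtype]
                | TYPE_ID[itype] << 2
                | TYPE_ID[rtype] << 4
                | TYPE_ID[otype] << 6
                | TYPE_ID[ctype] << 8)
    return (type_key << 32) | hidden_size

# e.g. fp16 input/residual/output with fp32 weights and compute, hidden size 1024:
print(hex(launcher_key("fp32", "fp16", "fp16", "fp16", "fp32", 1024)))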
-#include - -#include "ln.h" - -/* - -Supported Type combinations: - -input residual compute weights output -============================================ -fp32 fp32 fp32 fp32 fp32 -fp16 fp32 fp32 fp32 fp16 -fp16 fp16 fp32 fp32 fp16 -bf16 fp32 fp32 fp32 bf16 -bf16 bf16 fp32 fp32 bf16 -fp16 fp16 fp32 fp16 fp16 -bf16 bf16 fp32 bf16 bf16 - -Remarks: -Output type = Input type -Compute always in FP32 - -*/ - -namespace layer_norm { - -// Create registries and provide runtime versions of config hash functions. - -FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; -BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(torch::Dtype dtype){ - if( dtype == torch::kFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kBFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kFloat32 ) { - return TypeId::Value; - } else { - TORCH_CHECK(false, "Type not supported: ", dtype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); - if( iter != layer_norm::FWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); - if( iter != layer_norm::BWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); - if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype 
rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); - if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size - c10::optional &residual_, // Residual: BxSxhidden_size - const at::Tensor &gamma, // hidden_size - c10::optional &beta_, // hidden_size - c10::optional &rowscale_, // BxS - c10::optional &colscale_, // hidden_size - c10::optional &x0_subset_, // BxS - c10::optional &z_subset_, // BxS - const float dropout_p, - const float epsilon, - const float rowscale_const, - const int64_t z_numrows, - c10::optional gen_, - bool residual_in_fp32=false, - bool is_rms_norm=false -) { - auto itype = x0.scalar_type(); - auto rtype = residual_.has_value() - ? residual_.value().scalar_type() - : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type()); - auto wtype = gamma.scalar_type(); - auto otype = itype; - auto ctype = torch::kFloat32; - auto mtype = torch::kUInt8; - - TORCH_CHECK(x0.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); - - TORCH_CHECK(x0.is_contiguous()); - // c10::IntArrayRef does not own the storage, so we need to construct a vector. - // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because - // blah is then deallocated. - std::vector sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)}; - auto sizes = c10::IntArrayRef(sizes_vec); - TORCH_CHECK(x0.dim() == 2); - TORCH_CHECK(sizes.size() == 2); - - const int rows = sizes[0]; - const int cols = sizes[1]; - auto hidden_size = gamma.numel(); - TORCH_CHECK(hidden_size == cols); - - if (beta_.has_value()) { - auto beta = beta_.value(); - TORCH_CHECK(beta.dtype() == wtype); - TORCH_CHECK(beta.is_cuda()); - TORCH_CHECK(beta.is_contiguous()); - TORCH_CHECK(beta.sizes() == gamma.sizes()); - } - - if (residual_.has_value()) { - auto residual = residual_.value(); - TORCH_CHECK(residual.is_cuda()); - TORCH_CHECK(residual.is_contiguous()); - TORCH_CHECK(residual.sizes() == sizes); - } - - if (rowscale_.has_value()) { - auto rowscale = rowscale_.value(); - TORCH_CHECK(rowscale.is_cuda()); - TORCH_CHECK(rowscale.is_contiguous()); - TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(rowscale.dtype() == itype); - } - - if (colscale_.has_value()) { - auto colscale = colscale_.value(); - TORCH_CHECK(colscale.is_cuda()); - TORCH_CHECK(colscale.is_contiguous()); - TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); - TORCH_CHECK(colscale.dtype() == wtype); - } - - if (x0_subset_.has_value()) { - auto x0_subset = x0_subset_.value(); - TORCH_CHECK(x0_subset.is_cuda()); - TORCH_CHECK(x0_subset.is_contiguous()); - TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(x0_subset.dtype() == torch::kInt32); - - TORCH_CHECK(z_subset_.has_value()); - auto z_subset = z_subset_.value(); - TORCH_CHECK(z_subset.is_cuda()); - TORCH_CHECK(z_subset.is_contiguous()); - TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(z_subset.dtype() == torch::kInt32); - } - - TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); - TORCH_CHECK(epsilon >= 0.f); - - // Otherwise the 
kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x0.get_device()}; - - auto opts = x0.options(); - - bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype); - at::Tensor x; - if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } - at::Tensor dmask; - if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); }; - auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype)); - - auto mu = torch::empty({ rows }, opts.dtype(ctype)); - auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); - - layer_norm::LaunchParams launch_params; - - launch_params.props = at::cuda::getCurrentDeviceProperties(); - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(dropout_p < 1.f); - launch_params.params.dropout_keep_p = 1.f - dropout_p; - launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr; - launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; - launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; - launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; - launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); - // Request the kernel launcher. - auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - - // Set the kernel runtime parameters. - layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x0 = x0.data_ptr(); - params.x = save_x ? x.data_ptr() : nullptr; - params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr; - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr; - params.z = z.data_ptr(); - params.epsilon = epsilon; - params.dropout_scale = 1.f / (1.f - dropout_p); - params.inverse_cols = 1.f / float(params.cols); - params.rowscale_const = rowscale_const; - params.is_rms_norm = is_rms_norm; - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - at::Tensor workspace, barrier; - - if (dropout_p > 0.f) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - int64_t counter_offset = launch_params.elts_per_thread; - - // See Note [Acquire lock when using random generators] - { - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - } - - if( launch_params.barrier_size > 0 ) { - auto options = x0.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - // Launch the kernel. 
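For reference, the launcher selection above rounds the hidden size up to the nearest registered kernel width, with the rounding multiple depending on the size; the dropout bookkeeping keeps `1 - p` and its reciprocal as the rescale factor. A small Python sketch of both:

def kernel_hidden_size(hidden_size):
    # Round up to a multiple of 256 (<= 1536), 512 (<= 3072) or 1024 (larger).
    multiple = 256 if hidden_size <= 1536 else (512 if hidden_size <= 3072 else 1024)
    return (hidden_size + multiple - 1) // multiple * multiple

def dropout_params(dropout_p):
    assert dropout_p < 1.0
    return 1.0 - dropout_p, 1.0 / (1.0 - dropout_p)  # (keep probability, rescale factor)

print(kernel_hidden_size(1000))  # 1024
print(kernel_hidden_size(5000))  # 5120
print(dropout_params(0.1))       # (0.9, 1.111...)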
- launcher(launch_params, false); - - return { z, x, dmask, mu, rsigma }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size - c10::optional &dx_, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - c10::optional &x0_, // BxSxhidden_size - c10::optional &dmask_, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! - const at::Tensor &gamma, // hidden_size - c10::optional &rowscale_, // BxS - c10::optional &colscale_, // hidden_size - c10::optional &x0_subset_, // BxS - c10::optional &z_subset_, // BxS - const float dropout_p, - const float rowscale_const, - const int64_t x0_numrows, - const bool has_residual, - bool is_rms_norm=false -) { - - auto itype = dz.scalar_type(); - auto rtype = x.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = itype; - auto ctype = torch::kFloat32; - auto mtype = torch::kUInt8; - - if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } - - TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(mu.dtype() == ctype); - TORCH_CHECK(rsigma.dtype() == ctype); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(mu.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); - - TORCH_CHECK(x.is_contiguous()); - TORCH_CHECK(dz.is_contiguous()); - - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); - auto rows = sizes[0]; - auto cols = sizes[1]; - TORCH_CHECK(dz.dim() == 2); - TORCH_CHECK(dz.size(1) == cols); - auto hidden_size = gamma.numel(); - TORCH_CHECK(hidden_size == cols); - - // c10::IntArrayRef does not own the storage, so we need to construct a vector. - // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because - // blah is then deallocated. - std::vector x0_sizes_vec {!x0_subset_.has_value() ? 
rows : x0_numrows, cols}; - auto x0_sizes = c10::IntArrayRef(x0_sizes_vec); - - if (dx_.has_value()) { - auto dx = dx_.value(); - TORCH_CHECK(dx.dtype() == rtype); - TORCH_CHECK(dx.is_cuda()); - TORCH_CHECK(dx.is_contiguous()); - TORCH_CHECK(dx.sizes() == sizes); - } - - if (dmask_.has_value()) { - auto dmask = dmask_.value(); - TORCH_CHECK(dmask.dtype() == mtype); - TORCH_CHECK(dmask.is_cuda()); - TORCH_CHECK(dmask.is_contiguous()); - TORCH_CHECK(dmask.sizes() == x0_sizes); - } - - if (rowscale_.has_value()) { - auto rowscale = rowscale_.value(); - TORCH_CHECK(rowscale.is_cuda()); - TORCH_CHECK(rowscale.is_contiguous()); - TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(rowscale.dtype() == itype); - } - - if (colscale_.has_value()) { - auto colscale = colscale_.value(); - TORCH_CHECK(colscale.is_cuda()); - TORCH_CHECK(colscale.is_contiguous()); - TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); - TORCH_CHECK(colscale.dtype() == wtype); - - TORCH_CHECK(x0_.has_value()); - auto x0 = x0_.value(); - TORCH_CHECK(x0.is_cuda()); - TORCH_CHECK(x0.is_contiguous()); - TORCH_CHECK(x0.sizes() == x0_sizes); - TORCH_CHECK(x0.dtype() == itype); - } - - if (x0_subset_.has_value()) { - auto x0_subset = x0_subset_.value(); - TORCH_CHECK(x0_subset.is_cuda()); - TORCH_CHECK(x0_subset.is_contiguous()); - TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(x0_subset.dtype() == torch::kInt32); - - TORCH_CHECK(z_subset_.has_value()); - auto z_subset = z_subset_.value(); - TORCH_CHECK(z_subset.is_cuda()); - TORCH_CHECK(z_subset.is_contiguous()); - TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); - TORCH_CHECK(z_subset.dtype() == torch::kInt32); - } - - TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); - - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - - TORCH_CHECK(gamma.numel() == cols); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)dz.get_device()}; - - auto opts = x.options(); - - auto dx0 = torch::empty(x0_sizes, opts.dtype(itype)); - at::Tensor dresidual; - if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); } - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); - at::Tensor dcolscale; - if (colscale_.has_value()) { - dcolscale = torch::empty_like(colscale_.value()); - } - - layer_norm::LaunchParams launch_params; - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - launch_params.props = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(dropout_p < 1.f); - launch_params.params.dropout_keep_p = 1.f - dropout_p; - launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr; - launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; - launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; - launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; - launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 
512 : 1024); - auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - - launcher(launch_params, true); - - auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - at::Tensor dcolscale_part; - if (colscale_.has_value()) { - dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - } - at::Tensor workspace, barrier; - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data_ptr(); - params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr; - params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr; - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.dz = dz.data_ptr(); - params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr; - params.dx0 = dx0.data_ptr(); - params.dbeta = dbeta.data_ptr(); - params.dgamma = dgamma.data_ptr(); - params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr; - params.dbeta_part = dbeta_part.data_ptr(); - params.dgamma_part = dgamma_part.data_ptr(); - params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr; - params.dropout_scale = 1.f / (1.f - dropout_p); - params.inverse_cols = 1.f / float(params.cols); - params.rowscale_const = rowscale_const; - params.is_rms_norm = is_rms_norm; - - if( launch_params.barrier_size > 0 ) { - // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - launcher(launch_params, false); - - std::vector result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part }; - if (colscale_.has_value()) { - result.push_back(dcolscale); - result.push_back(dcolscale_part); - } - return result; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector dropout_add_ln_parallel_residual_fwd( - const at::Tensor &x0, // Input: BxSxhidden_size - c10::optional &x1_, // Input: BxSxhidden_size - c10::optional &residual_, // Residual: BxSxhidden_size - const at::Tensor &gamma0, // hidden_size - c10::optional &beta0_, // hidden_size - c10::optional &gamma1_, // hidden_size - c10::optional &beta1_, // hidden_size - const float dropout_p, - const float epsilon, - c10::optional gen_, - bool residual_in_fp32=false, - bool is_rms_norm=false -) { - auto itype = x0.scalar_type(); - auto rtype = residual_.has_value() - ? residual_.value().scalar_type() - : (residual_in_fp32 ? 
torch::kFloat32 : x0.scalar_type()); - auto wtype = gamma0.scalar_type(); - auto otype = itype; - auto ctype = torch::kFloat32; - auto mtype = torch::kUInt8; - - TORCH_CHECK(x0.is_cuda()); - TORCH_CHECK(gamma0.is_cuda()); - - TORCH_CHECK(x0.is_contiguous()); - const auto sizes = x0.sizes(); - TORCH_CHECK(x0.dim() == 2); - - const int rows = sizes[0]; - const int cols = sizes[1]; - auto hidden_size = gamma0.numel(); - TORCH_CHECK(hidden_size == cols); - - if (x1_.has_value()) { - auto x1 = x1_.value(); - TORCH_CHECK(x1.is_cuda()); - TORCH_CHECK(x1.is_contiguous()); - TORCH_CHECK(x1.sizes() == sizes); - } - - if (residual_.has_value()) { - auto residual = residual_.value(); - TORCH_CHECK(residual.is_cuda()); - TORCH_CHECK(residual.is_contiguous()); - TORCH_CHECK(residual.sizes() == sizes); - } - - if (beta0_.has_value()) { - auto beta0 = beta0_.value(); - TORCH_CHECK(beta0.dtype() == wtype); - TORCH_CHECK(beta0.is_cuda()); - TORCH_CHECK(beta0.is_contiguous()); - TORCH_CHECK(beta0.sizes() == gamma0.sizes()); - } - - if (gamma1_.has_value()) { - auto gamma1 = gamma1_.value(); - TORCH_CHECK(gamma1.dtype() == wtype); - TORCH_CHECK(gamma1.is_cuda()); - TORCH_CHECK(gamma1.is_contiguous()); - TORCH_CHECK(gamma1.sizes() == gamma0.sizes()); - } - - if (beta1_.has_value()) { - auto beta1 = beta1_.value(); - TORCH_CHECK(beta1.dtype() == wtype); - TORCH_CHECK(beta1.is_cuda()); - TORCH_CHECK(beta1.is_contiguous()); - TORCH_CHECK(beta1.sizes() == gamma0.sizes()); - } - - TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); - TORCH_CHECK(epsilon >= 0.f); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x0.get_device()}; - - auto opts = x0.options(); - - bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype); - at::Tensor x; - if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } - at::Tensor dmask0, dmask1; - if (dropout_p > 0.f) { - dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype)); - if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); } - }; - auto z0 = torch::empty(sizes, opts.dtype(otype)); - at::Tensor z1; - if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); } - - auto mu = torch::empty({ rows }, opts.dtype(ctype)); - auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); - - layer_norm::LaunchParams launch_params; - - launch_params.props = at::cuda::getCurrentDeviceProperties(); - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(dropout_p < 1.f); - launch_params.params.dropout_keep_p = 1.f - dropout_p; - launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr; - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); - // Request the kernel launcher. - auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - - // Set the kernel runtime parameters. - layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x0 = x0.data_ptr(); - params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr; - params.x = save_x ? x.data_ptr() : nullptr; - params.dmask = dropout_p > 0.f ? 
dmask0.data_ptr() : nullptr; - params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr; - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma0.data_ptr(); - params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr; - params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr; - params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr; - params.z = z0.data_ptr(); - params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr; - params.epsilon = epsilon; - params.dropout_scale = 1.f / (1.f - dropout_p); - params.inverse_cols = 1.f / float(params.cols); - params.is_rms_norm = is_rms_norm; - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - at::Tensor workspace, barrier; - - if (dropout_p > 0.f) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - int64_t counter_offset = 2 * launch_params.elts_per_thread; - - // See Note [Acquire lock when using random generators] - { - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - } - - if( launch_params.barrier_size > 0 ) { - auto options = x0.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - // Launch the kernel. - launcher(launch_params, false); - - return { z0, z1, x, dmask0, dmask1, mu, rsigma }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector dropout_add_ln_parallel_residual_bwd( - const at::Tensor &dz0, // BxSxhidden_size - c10::optional &dz1_, // BxSxhidden_size - c10::optional &dx_, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - c10::optional &dmask0_, // BxSxhidden_size - c10::optional &dmask1_, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! 
- const at::Tensor &gamma0, // hidden_size - c10::optional &gamma1_, // hidden_size - const float dropout_p, - const bool has_x1, - const bool has_residual, - bool is_rms_norm=false -) { - - auto itype = dz0.scalar_type(); - auto rtype = x.scalar_type(); - auto wtype = gamma0.scalar_type(); - auto otype = itype; - auto ctype = torch::kFloat32; - auto mtype = torch::kUInt8; - - if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); } - - TORCH_CHECK(dz0.dtype() == otype); - TORCH_CHECK(dz0.dtype() == otype); - TORCH_CHECK(mu.dtype() == ctype); - TORCH_CHECK(rsigma.dtype() == ctype); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(dz0.is_cuda()); - TORCH_CHECK(mu.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma0.is_cuda()); - - TORCH_CHECK(x.is_contiguous()); - TORCH_CHECK(dz0.is_contiguous()); - - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); - auto rows = sizes[0]; - auto cols = sizes[1]; - TORCH_CHECK(dz0.dim() == 2); - TORCH_CHECK(dz0.size(1) == cols); - auto hidden_size = gamma0.numel(); - TORCH_CHECK(hidden_size == cols); - - if (dz1_.has_value()) { - auto dz1 = dz1_.value(); - TORCH_CHECK(dz1.dtype() == otype); - TORCH_CHECK(dz1.is_cuda()); - TORCH_CHECK(dz1.is_contiguous()); - TORCH_CHECK(dz1.sizes() == sizes); - - TORCH_CHECK(gamma1_.has_value()); - auto gamma1 = gamma1_.value(); - TORCH_CHECK(gamma1.dtype() == wtype); - TORCH_CHECK(gamma1.is_cuda()); - TORCH_CHECK(gamma1.is_contiguous()); - TORCH_CHECK(gamma1.sizes() == gamma0.sizes()); - } - - if (dx_.has_value()) { - auto dx = dx_.value(); - TORCH_CHECK(dx.dtype() == rtype); - TORCH_CHECK(dx.is_cuda()); - TORCH_CHECK(dx.is_contiguous()); - TORCH_CHECK(dx.sizes() == sizes); - } - - if (dmask0_.has_value()) { - auto dmask0 = dmask0_.value(); - TORCH_CHECK(dmask0.dtype() == mtype); - TORCH_CHECK(dmask0.is_cuda()); - TORCH_CHECK(dmask0.is_contiguous()); - TORCH_CHECK(dmask0.sizes() == sizes); - - if (has_x1) { - TORCH_CHECK(dmask1_.has_value()); - auto dmask1 = dmask1_.value(); - TORCH_CHECK(dmask1.dtype() == mtype); - TORCH_CHECK(dmask1.is_cuda()); - TORCH_CHECK(dmask1.is_contiguous()); - TORCH_CHECK(dmask1.sizes() == sizes); - } - } - - TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); - - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)dz0.get_device()}; - - auto opts = x.options(); - - auto dx0 = torch::empty(sizes, opts.dtype(itype)); - at::Tensor dx1; - if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); } - at::Tensor dresidual; - if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); } - auto dgamma0 = torch::empty_like(gamma0); - auto dbeta0 = torch::empty_like(gamma0); - at::Tensor dgamma1, dbeta1; - if (gamma1_.has_value()) { - dgamma1 = torch::empty_like(gamma0); - dbeta1 = torch::empty_like(gamma0); - } - - layer_norm::LaunchParams launch_params; - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - launch_params.props = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(dropout_p < 1.f); - launch_params.params.dropout_keep_p = 1.f - dropout_p; - launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr; - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 
512 : 1024); - auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - - launcher(launch_params, true); - - auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - at::Tensor dgamma1_part, dbeta1_part; - if (gamma1_.has_value()) { - dgamma1_part = torch::zeros_like(dgamma0_part); - dbeta1_part = torch::zeros_like(dbeta0_part); - } - at::Tensor workspace, barrier; - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data_ptr(); - params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr; - params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr; - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma0.data_ptr(); - params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr; - params.dz = dz0.data_ptr(); - params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr; - params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr; - params.dx0 = dx0.data_ptr(); - params.dx1 = has_x1 ? dx1.data_ptr() : nullptr; - params.dbeta = dbeta0.data_ptr(); - params.dgamma = dgamma0.data_ptr(); - params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr; - params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr; - params.dbeta_part = dbeta0_part.data_ptr(); - params.dgamma_part = dgamma0_part.data_ptr(); - params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr; - params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr; - params.dropout_scale = 1.f / (1.f - dropout_p); - params.inverse_cols = 1.f / float(params.cols); - params.is_rms_norm = is_rms_norm; - - if( launch_params.barrier_size > 0 ) { - // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - launcher(launch_params, false); - - std::vector result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part }; - return result; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/layer_norm/ln_bwd_1024.cu b/layer_norm/ln_bwd_1024.cu deleted file mode 100644 index f7101f6450fcdb8baa4ff4e79379d913048696b6..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_1024.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_bwd_1280.cu b/layer_norm/ln_bwd_1280.cu deleted file mode 100644 index a80a5762a178bd1fd1cd2ef4d0fb2010c1eea22e..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_1280.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_bwd_1536.cu b/layer_norm/ln_bwd_1536.cu deleted file mode 100644 index 0c25c088494d52f3b68251235d29c23a46ffc430..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_1536.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
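The registration lines are positional; read against the macro signature above, one of the 1024-wide entries decodes as follows (an illustrative Python spelling-out, not part of the kernel code):

from collections import namedtuple

BwdLaunchConfig = namedtuple(
    "BwdLaunchConfig",
    "hidden_size wtype itype rtype otype ctype "
    "ctas_per_row warps_m warps_n bytes_per_ldg bytes_per_ldg_final")

# REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4)
cfg = BwdLaunchConfig(1024, "fp16", "fp16", "fp16", "fp16", "fp32", 1, 4, 1, 16, 4)
print(cfg)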
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/layer_norm/ln_bwd_2048.cu b/layer_norm/ln_bwd_2048.cu deleted file mode 100644 index 06c0e608a3e48ec7fad2081bc6ff82425ea1c56a..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_2048.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_256.cu b/layer_norm/ln_bwd_256.cu deleted file mode 100644 index 20945432b8e97be21d80ada73aa0b3e709733a5b..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_256.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_bwd_2560.cu b/layer_norm/ln_bwd_2560.cu deleted file mode 100644 index 309184c37b93e1f90bc1020a47973dae84f0f0c8..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_2560.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/layer_norm/ln_bwd_3072.cu b/layer_norm/ln_bwd_3072.cu deleted file mode 100644 index e156b11cd92f450a6ce8e0c432487bd36d6f9847..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_3072.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_4096.cu b/layer_norm/ln_bwd_4096.cu deleted file mode 100644 index b715b0efe48c4111ae4301365018d19f537c7a81..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_4096.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_512.cu b/layer_norm/ln_bwd_512.cu deleted file mode 100644 index 2b472118f0a0025917edc4c706492ca5dc8fa205..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_512.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_bwd_5120.cu b/layer_norm/ln_bwd_5120.cu deleted file mode 100644 index 38f3fbd406db8989f4a9806e64075bf52444c529..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_5120.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_6144.cu b/layer_norm/ln_bwd_6144.cu deleted file mode 100644 index 469ed4b6c7691c581bbd1db5b8587de860afcb16..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_6144.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_7168.cu b/layer_norm/ln_bwd_7168.cu deleted file mode 100644 index 549eab11aa3c770bea97bda727495f3e141ec24b..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_7168.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_768.cu b/layer_norm/ln_bwd_768.cu deleted file mode 100644 index 5db64d3d7b184f6ffb01ae0e1a26e0acec3bbe3d..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_768.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_bwd_8192.cu b/layer_norm/ln_bwd_8192.cu deleted file mode 100644 index e6514e613fe9cbf444ad4919a5acf9579b216c9e..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_8192.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_bwd_kernels.cuh b/layer_norm/ln_bwd_kernels.cuh deleted file mode 100644 index c7261d218442acbcf60b61ce2e8803556193d8cd..0000000000000000000000000000000000000000 --- a/layer_norm/ln_bwd_kernels.cuh +++ /dev/null @@ -1,534 +0,0 @@ -#pragma once - -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "static_switch.h" - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_bwd_kernel(layer_norm::BwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using input_t = typename Ktraits::input_t; - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using mask_t = typename Ktraits::mask_t; - using Ivec = typename Ktraits::Ivec; - using Rvec = typename Ktraits::Rvec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Mvec = typename Ktraits::Mvec; - 
using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const bool has_residual = params.dresidual != nullptr; - const bool prenorm = params.dx != nullptr; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - const input_t *rowscale = static_cast(params.rowscale); - const index_t *x0_subset = static_cast(params.x0_subset); - const index_t *z_subset = static_cast(params.z_subset); - - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - Cvec dcolscale_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); } - - compute_t * smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - const index_t num_valid_ldgs = - ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; - - Wvec gamma[LDGS]; - Wvec colscale[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - gamma[it].load_from(params.gamma, idx); - if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } - idx += Ktraits::VEC_COLS_PER_LDG; - } - } - // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the - // last blocks with syncthreads! - // grid stride over rows - #pragma unroll 1 - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; - const compute_t rs_r = static_cast(params.rs)[row]; - const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; - const int row_z = !Has_subset ? row + 1 : z_subset[row]; - const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; - const bool load_dz = !Has_subset || row_z > 0; - const bool save_dx0 = !Has_subset || row_x0 > 0; - Mvec dmask[LDGS]; - Rvec dx[LDGS]; - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; - // If dz is not loaded, then dy should be 0 and we don't care about the value of y. - if (load_dz) { - index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; - index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Rvec x; - Ovec dz; - dz.load_from(params.dz, !Has_subset ? idx_x : idx_z); - if (prenorm) { dx[it].load_from(params.dx, idx_x); } - x.load_from(params.x, idx_x); - if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? 
idx_x : idx_x0); } - idx_x += Ktraits::VEC_COLS_PER_LDG; - idx_z += Ktraits::VEC_COLS_PER_LDG; - idx_x0 += Ktraits::VEC_COLS_PER_LDG; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x.data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f)); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]); - compute_t dz_tmp = dz.data.elt[jt]; - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; - } - } - } - } else { - index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; - index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - if (prenorm) { dx[it].load_from(params.dx, idx_x); } - if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } - idx_x += Ktraits::VEC_COLS_PER_LDG; - idx_x0 += Ktraits::VEC_COLS_PER_LDG; - } - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; - mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; - - index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; - index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ivec dx0; - Rvec dresidual; - Ivec x0; - if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dx_tmp_res; - if (load_dz) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); - dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; - } else { - dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f; - } - if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } - if (save_dx0) { - compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; - if (Is_dropout) { - dx0_tmp_res *= params.dropout_scale; - if (Has_colscale) { - dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; - dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; - } else { - dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; - } - } else { - if (Has_colscale) { - dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); - dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); - } else { - dx0.data.elt[jt] = dx0_tmp_res; - } - } - } - } - if (has_residual) { dresidual.store_to(params.dresidual, idx_x); } - if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? 
idx_x : idx_x0); } - idx_x += Ktraits::VEC_COLS_PER_LDG; - idx_x0 += Ktraits::VEC_COLS_PER_LDG; - } - } - - } // end: grid stride loop - - if( WARPS_M == 1 ) { - idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - dz_sum[it].store_to(params.dbeta_part, idx); - dzy_sum[it].store_to(params.dgamma_part, idx); - if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); } - idx += Ktraits::VEC_COLS_PER_LDG; - } - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. - enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz_sum[NUM_RES]; - memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t cta_dcolscale_sum[NUM_RES]; - if (Has_colscale) { - __syncthreads(); - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dcolscale_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - } - - const index_t num_valid_writes - = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; - compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; - compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; - compute_t *dcolscale_part = Has_colscale ? 
static_cast(params.dcolscale_part) + bidm * params.cols + tidx : nullptr; - for( int jt = 0; jt < NUM_RES; jt++ ) { - if (Is_even_cols || (jt < num_valid_writes)) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - *dbeta_part = cta_dz_sum[jt]; - dbeta_part += Ktraits::THREADS_PER_CTA; - if (Has_colscale) { - *dcolscale_part = cta_dcolscale_sum[jt]; - dcolscale_part += Ktraits::THREADS_PER_CTA; - } - } - } - - } -} - -template -__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) -void ln_bwd_finalize_kernel(BwdParams params) -{ - - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { - // Each thread sums over NUM_ELT columns. - Vec dbeta_local, dgamma_local, dcolscale_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - memset(&dbeta_local, 0, sizeof(dbeta_local)); - if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } - if (Is_even_cols || col < params.cols) { - for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { - index_t idx = row * params.cols + col; - - Vec dbeta_part, dgamma_part, dcolscale_part; - dbeta_part.load_from(params.dbeta_part, idx); - dgamma_part.load_from(params.dgamma_part, idx); - if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); } - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; - if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; } - } - } - } - void * smem_gamma = smem_; - void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma_local.store_to(smem_gamma, write_idx); - dbeta_local.store_to(smem_beta, write_idx); - if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); } - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta and smem_gamma - void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; - - - // More than one iter iff ROWS_PER_CTA < 32. 
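// Illustration of the transpose below: each warp wrote its partial sums to row `warp`
// of shared memory at an XOR-swizzled column (write_col = lane ^ write_row), and the
// loop that follows reads them back transposed (read_col = w ^ read_row). For example,
// warp 3, lane 5 stored to column 5 ^ 3 = 6; the read pass for output column w = 5 has
// its lane-3 thread recompute 5 ^ 3 = 6 and pick the same value back up. The swizzle
// keeps a warp's 32 writes in 32 distinct columns of its row, which is what the
// "Conflict-free smem transpose" static_assert in Kernel_traits_finalize refers to.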
- for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta_local, 0, sizeof(dbeta_local)); - memset(&dgamma_local, 0, sizeof(dgamma_local)); - if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } - - // Load beta and gamma transposed - if(read_row < Kernel_traits::ROWS_PER_CTA){ - dbeta_local.load_from(smem_beta, read_idx); - dgamma_local.load_from(smem_gamma, read_idx); - if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); } - } - - // Call reducer on the loaded value(s) and convert. - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); - - dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; - if (Has_colscale) { - compute_t cs_i = dcolscale_local.data.elt[it]; - cs_i = reducer.allreduce(cs_i, sum); - dcolscale_local.data.elt[it] = cs_i; - } - } - - // Leader stores the result at the current column. - if(lane == 0){ - dgamma_local.store_to(smem_gamma_out, w); - dbeta_local.store_to(smem_beta_out, w); - if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); } - } - - } - - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. - if (Is_even_cols || col_out * 2 < params.cols) { - if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { - - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2, dcolscale_vec2; - Vec dbeta_out2, dgamma_out2, dcolscale_out2; - - dgamma_vec2.load_from(smem_gamma_out, lane); - dbeta_vec2.load_from(smem_beta_out, lane); - if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); } - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); - dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); - if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter::convert(dcolscale_vec2.data.elt[it]); } - } - dgamma_out2.store_to(params.dgamma, col_out); - dbeta_out2.store_to(params.dbeta, col_out); - if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); } - } - } - } -} -} // namespace layer_norm - -using namespace layer_norm; - -template< - typename weight_t, - typename input_t, - typename residual_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - bool is_dropout = launch_params.params.dropout_keep_p < 1.f; - bool has_colscale = launch_params.params.colscale != nullptr; - bool has_subset = launch_params.params.x0_subset != nullptr; - bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { - BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { - BOOL_SWITCH(has_subset, HasSubsetConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_bwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, 
Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); - }); - }); - }); - }); -} diff --git a/layer_norm/ln_fwd_1024.cu b/layer_norm/ln_fwd_1024.cu deleted file mode 100644 index 824d86e9fd05920d3e557b42356feec86c904f68..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_1024.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_1280.cu b/layer_norm/ln_fwd_1280.cu deleted file mode 100644 index 1ff58cbc2889a2c06c51df560d2b35ca4e079201..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_1280.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_1536.cu b/layer_norm/ln_fwd_1536.cu deleted file mode 100644 index a8e19d4dba97d91cd246e62ba80a2936ac05755c..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_1536.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_2048.cu b/layer_norm/ln_fwd_2048.cu deleted file mode 100644 index 6f9794c1e77f91a333d64cc6e461560622b87e12..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_2048.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_256.cu b/layer_norm/ln_fwd_256.cu deleted file mode 100644 index f3a541c6dbf20cd94bb56607bbb23e6a81059bdc..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_256.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_2560.cu b/layer_norm/ln_fwd_2560.cu deleted file mode 100644 index 1650671e059ec358f8109c1d592694458e77d489..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_2560.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_3072.cu b/layer_norm/ln_fwd_3072.cu deleted file mode 100644 index 25bb8691dc9f6a95297301efbd91567a5c22d1c2..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_3072.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_fwd_4096.cu b/layer_norm/ln_fwd_4096.cu deleted file mode 100644 index b2bffb5831bf1b6eb18cd1e2cd2c4636a06f5736..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_4096.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_fwd_512.cu b/layer_norm/ln_fwd_512.cu deleted file mode 100644 index a08fe34c55d61eecdbc74caa41dfbec10b3a8126..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_512.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_5120.cu b/layer_norm/ln_fwd_5120.cu deleted file mode 100644 index bebbd69f05b38a5e3c0dae5d248de467118ef8c5..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_5120.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_fwd_6144.cu b/layer_norm/ln_fwd_6144.cu deleted file mode 100644 index 4df01ead2f292e255221e6fb0b48e63941a22cab..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_6144.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/layer_norm/ln_fwd_7168.cu b/layer_norm/ln_fwd_7168.cu deleted file mode 100644 index 8343666d10c2788cb2c19ba4f448eef2ccf2b956..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_7168.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_fwd_768.cu b/layer_norm/ln_fwd_768.cu deleted file mode 100644 index 06d5a3b09cdd4941764885f5107bbbfa6b264eef..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_768.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_fwd_8192.cu b/layer_norm/ln_fwd_8192.cu deleted file mode 100644 index bf7cb40252baf820c88dff1337c81dffd934087a..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_8192.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/layer_norm/ln_fwd_kernels.cuh b/layer_norm/ln_fwd_kernels.cuh deleted file mode 100644 index f6bccb8c28a2b3d967dddc3d8b21e1888ed2e29c..0000000000000000000000000000000000000000 --- a/layer_norm/ln_fwd_kernels.cuh +++ /dev/null @@ -1,272 +0,0 @@ -#pragma once - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include // For at::cuda::philox::unpack -#include - -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "static_switch.h" - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_kernel(FwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using input_t = typename Ktraits::input_t; - using residual_t = typename Ktraits::residual_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using mask_t = typename Ktraits::mask_t; - using Ivec = typename Ktraits::Ivec; - using Rvec = typename Ktraits::Rvec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Mvec = typename Ktraits::Mvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - const bool has_residual = params.residual != nullptr; - const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - const input_t *rowscale = static_cast(params.rowscale); - const index_t *x0_subset = static_cast(params.x0_subset); - const index_t *z_subset = static_cast(params.z_subset); - - // 
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu - curandStatePhilox4_32_10_t state; - if (Is_dropout) { - auto seeds = at::cuda::philox::unpack(params.philox_args); - const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; - curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); - } - - const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - Wvec colscale[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - gamma[it].load_from(params.gamma, idx); - if (params.beta != nullptr) { - beta[it].load_from(params.beta, idx); - } else { - beta[it].zero_(); - } - if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } - idx += VEC_COLS_PER_LDG; - } - } - - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; - const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; - const int row_z = !Has_subset ? row + 1 : z_subset[row]; - const bool load_x0 = !Has_subset || row_x0 > 0; - index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; - index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ivec x0; - Rvec residual; - Rvec x; - Mvec dmask; - if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } - if (has_residual) { residual.load_from(params.residual, idx_x); } - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use - // the more efficient curand_uniform4. - compute_t x_ij; - if (load_x0) { - mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; - if (Is_dropout) { dmask.data.elt[jt] = keep; } - compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; - x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; - if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } - x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; - } else { - x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f; - } - if (save_x) { x.data.elt[jt] = x_ij; } - xf[it * NUM_ELTS + jt] = x_ij; - } - if (save_x) { x.store_to(params.x, idx_x); } - if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); } - idx_x += VEC_COLS_PER_LDG; - idx_x0 += VEC_COLS_PER_LDG; - } - } - - static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); - const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; - const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; - const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; - auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { - // Need to convert to int, otherwise the subtraction will wrap around. 
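// Hypothetical worked example (fp16 input, BYTES_PER_LDG = 16, CTAS_PER_ROW = 1,
// WARPS_N = 4): params.cols = 520 gives num_vecs = 520 / 8 = 65; with
// VEC_COLS_PER_LDG = 128 there are no full loads and remaining_vecs = 65, so
// warp_n = 0 and 1 each see 32 valid vectors, warp_n = 2 sees 1, and warp_n = 3 sees 0.
// Without the int casts, 65u - 96u would wrap around and the min() below would
// report 32 valid vectors for warp_n = 3 instead of 0.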
- const index_t valid_partial_vecs_in_warp = - std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), - int(THREADS_PER_WARP)); - return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; - }; - stats_t s = stats.template compute( - xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS - ); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; - } - - const bool save_z = !Has_subset || row_z > 0; - if (save_z) { - index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ovec z; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); - compute_t g_ij = gamma[it].data.elt[jt]; - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); - } - z.store_to(params.z, idx_z); - idx_z += VEC_COLS_PER_LDG; - } - } - } - - } -} - -} // namespace layer_norm - -using namespace layer_norm; - -template< - typename weight_t, - typename input_t, - typename residual_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - bool has_colscale = launch_params.params.colscale != nullptr; - bool has_subset = launch_params.params.x0_subset != nullptr; - bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { - BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { - BOOL_SWITCH(has_subset, HasSubsetConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_fwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; - launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 
grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } - }); - }); - }); - }); -} diff --git a/layer_norm/ln_kernel_traits.h b/layer_norm/ln_kernel_traits.h deleted file mode 100644 index 77de6bf9af60c9ae70427097db26cf4ed130b359..0000000000000000000000000000000000000000 --- a/layer_norm/ln_kernel_traits.h +++ /dev/null @@ -1,172 +0,0 @@ -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_ -> -struct Kernel_traits_base { - - using weight_t = weight_t_; - using input_t = input_t_; - using residual_t = residual_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - bool Has_colscale, - uint32_t THREADS_PER_CTA_, - uint32_t BYTES_PER_LDG_, - typename Base = Kernel_traits_base -> -struct Kernel_traits_finalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. - enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalsece the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2; - enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT }; - - // The type of the reducer. - using Reducer = layer_norm::Reducer; - - // Condition for the whole CTA to participate in syncthreads. 
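// For example, HIDDEN_SIZE_ = 4096 with fp32 compute_t_ and BYTES_PER_LDG_ = 4 (the
// BYTES_PER_LDG_FINAL value the REGISTER_* launchers pass in) gives ELTS_PER_LDG = 1,
// COLS = 4096 and CTAS = 4096 / 32 = 128, so each finalize CTA owns one full 32-column
// slice and the divisibility check below guarantees every warp reaches the
// __syncthreads() barriers with work to do.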
- static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template< - typename weight_t_, - typename input_t_, - typename residual_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, - uint32_t BYTES_PER_LDG_ = 16, - typename Base = Kernel_traits_base< - HIDDEN_SIZE_, - weight_t_, - input_t_, - residual_t_, - output_t_, - compute_t_, - index_t_, - WARPS_M_*WARPS_N_*THREADS_PER_WARP - > -> -struct Kernel_traits : public Base { - - using input_t = typename Base::input_t; - using residual_t = typename Base::residual_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - // using mask_t = unsigned char; - using mask_t = bool; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = layer_norm::Vec; - using Rvec = layer_norm::Vec; - using Ovec = layer_norm::Vec; - using Wvec = layer_norm::Vec; - using Cvec = layer_norm::Vec; - using Mvec = layer_norm::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements in the output and weights as in the input. - static_assert(sizeof(input_t) == sizeof(output_t)); - static_assert(sizeof(input_t) <= sizeof(residual_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. 
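// [Editor's note - illustrative sketch, not part of the deleted sources.] The vectorization
// arithmetic in the surrounding Kernel_traits maps one hidden row onto fixed-width global
// loads. For a hypothetical fp16 configuration with HIDDEN_SIZE = 1024, CTAS_PER_ROW = 1 and
// WARPS_N = 1 (values chosen here only to make the numbers concrete), the quantities work out
// as follows:
namespace ln_traits_example {
    constexpr int kBytesPerLdg   = 16;                               // one 128-bit global load
    constexpr int kEltsPerLdg    = kBytesPerLdg / 2;                 // 8 fp16 elements per load
    constexpr int kThreadsPerRow = 1 /*CTAS_PER_ROW*/ * 1 /*WARPS_N*/ * 32;
    constexpr int kVecCols       = 1024 / kEltsPerLdg;               // 128 vectors per row
    constexpr int kLdgs          = kVecCols / kThreadsPerRow;        // 4 loads per thread
    static_assert(kLdgs * kThreadsPerRow == kVecCols, "LDGS must tile the row exactly");
}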
- enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - - using Stats = layer_norm::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/layer_norm/ln_parallel_bwd_1024.cu b/layer_norm/ln_parallel_bwd_1024.cu deleted file mode 100644 index 6f4e77466c6c6d5a00275d54f4e68da062a5fc1a..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_1024.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_parallel_bwd_1280.cu b/layer_norm/ln_parallel_bwd_1280.cu deleted file mode 100644 index 2dba3bebf26e99b853e7ef4b9b56421cf483e0bd..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_1280.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_parallel_bwd_1536.cu b/layer_norm/ln_parallel_bwd_1536.cu deleted file mode 100644 index c2ac4b1b0998ca412dea02466f0d8fbe69f48216..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_1536.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
diff --git a/layer_norm/ln_parallel_bwd_2048.cu b/layer_norm/ln_parallel_bwd_2048.cu
deleted file mode 100644
index f7f959e2fa785a4df3b6a32506f527e1723d83cc..0000000000000000000000000000000000000000
--- a/layer_norm/ln_parallel_bwd_2048.cu
+++ /dev/null
@@ -1,15 +0,0 @@
-#include "ln_parallel_residual_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
diff --git a/layer_norm/ln_parallel_bwd_256.cu b/layer_norm/ln_parallel_bwd_256.cu
deleted file mode 100644
index fa613cf45e1045d046cefc4afd55ded754bc20a4..0000000000000000000000000000000000000000
--- a/layer_norm/ln_parallel_bwd_256.cu
+++ /dev/null
@@ -1,15 +0,0 @@
-#include "ln_parallel_residual_bwd_kernels.cuh"
-
-// Create backward launch function and register.
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_parallel_bwd_2560.cu b/layer_norm/ln_parallel_bwd_2560.cu deleted file mode 100644 index 5f5707612df09149885d7883728672dc3a2b751f..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_2560.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); diff --git a/layer_norm/ln_parallel_bwd_3072.cu b/layer_norm/ln_parallel_bwd_3072.cu deleted file mode 100644 index 8fdcb8ffb4d0f0e0fcae6aee930808bd0349ede5..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_3072.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_bwd_4096.cu b/layer_norm/ln_parallel_bwd_4096.cu deleted file mode 100644 index 8decfb085ac8ace1e3694a491bb66a83209027b8..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_4096.cu +++ /dev/null @@ -1,17 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -// Use 8 warps otherwise there's a lot of register spilling - -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_bwd_512.cu b/layer_norm/ln_parallel_bwd_512.cu deleted file mode 100644 index 178453d3045bfefd95018320d357ea8662018782..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_512.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_parallel_bwd_5120.cu b/layer_norm/ln_parallel_bwd_5120.cu deleted file mode 100644 index 815521973da7266534c7e8b167fa0b8baa47fa2c..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_5120.cu +++ /dev/null @@ -1,17 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -// Use 8 warps otherwise there's a lot of register spilling - -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_bwd_6144.cu b/layer_norm/ln_parallel_bwd_6144.cu deleted file mode 100644 index eb8668d8a229d2ec24e5eac57db00f9d650615eb..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_6144.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_bwd_7168.cu b/layer_norm/ln_parallel_bwd_7168.cu deleted file mode 100644 index 0c12dc476678ce7b24c5fcd0b9408eb686bd6825..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_7168.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_bwd_768.cu b/layer_norm/ln_parallel_bwd_768.cu deleted file mode 100644 index 8beece8ab19cea2baefedd118f5d15c90a646526..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_768.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/layer_norm/ln_parallel_bwd_8192.cu b/layer_norm/ln_parallel_bwd_8192.cu deleted file mode 100644 index 5ad47c94fdff599dde62574d1c535c4bbacae551..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_bwd_8192.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_bwd_kernels.cuh" - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); \ No newline at end of file diff --git a/layer_norm/ln_parallel_fwd_1024.cu b/layer_norm/ln_parallel_fwd_1024.cu deleted file mode 100644 index 3c64e169302eea0f94ff65641728c35689d7c4ba..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_1024.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
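// [Editor's note - illustrative sketch, not part of the deleted sources.] Each of the small
// .cu files in this section instantiates a launcher for one hidden size / dtype combination
// and, through the REGISTER_PARALLEL_*_LAUNCHER macros, adds it to a lookup table at
// static-initialization time so the host-side API can dispatch by key at runtime. A minimal
// stand-alone version of that pattern is sketched below; the names (Registrar, registry,
// REGISTER_LAUNCHER, LaunchParamsSketch) are hypothetical and are not the deleted code's
// actual symbols.
#include <cstdint>
#include <functional>
#include <unordered_map>

struct LaunchParamsSketch { int rows = 0, cols = 0; };
using LauncherFn = std::function<void(LaunchParamsSketch&, bool configure_params)>;

inline std::unordered_map<uint64_t, LauncherFn>& registry() {
    static std::unordered_map<uint64_t, LauncherFn> table;   // keyed by (hidden size, dtype tag)
    return table;
}

struct Registrar {
    Registrar(uint64_t key, LauncherFn fn) { registry().emplace(key, std::move(fn)); }
};

// One such line per supported (hidden size, dtype) combination, mirroring the macro lists above.
#define REGISTER_LAUNCHER(HIDDEN, DTYPE_TAG, FN) \
    static Registrar reg_##DTYPE_TAG##_##HIDDEN{(uint64_t(HIDDEN) << 8) | uint64_t(DTYPE_TAG), FN}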
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_1280.cu b/layer_norm/ln_parallel_fwd_1280.cu deleted file mode 100644 index 9bbfce5bc6c5e0303d70552bb36cf380601dcd38..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_1280.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_1536.cu b/layer_norm/ln_parallel_fwd_1536.cu deleted file mode 100644 index b57f5edce8eb7b6779475f6eadb8aabba299c802..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_1536.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_2048.cu b/layer_norm/ln_parallel_fwd_2048.cu deleted file mode 100644 index 6fa322d96b4e11aacf5722985672e141f929299b..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_2048.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_256.cu b/layer_norm/ln_parallel_fwd_256.cu deleted file mode 100644 index 27445a6bc50c98935c7a5093ee5ffdddf52e2494..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_256.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); \ No newline at end of file diff --git a/layer_norm/ln_parallel_fwd_2560.cu b/layer_norm/ln_parallel_fwd_2560.cu deleted file mode 100644 index fdde470c267302adca3d63f2c6b736b67af7ee86..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_2560.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_3072.cu b/layer_norm/ln_parallel_fwd_3072.cu deleted file mode 100644 index 992f71037607066fb4e4d0f1624669f21c2f53b1..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_3072.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_parallel_fwd_4096.cu b/layer_norm/ln_parallel_fwd_4096.cu deleted file mode 100644 index 381837e60874e44aa5e0efccb8749b2ff41ac3fa..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_4096.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_parallel_fwd_512.cu b/layer_norm/ln_parallel_fwd_512.cu deleted file mode 100644 index 4ba478b01fbdbc2ff5aab0a15fb698eba369f61a..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_512.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_5120.cu b/layer_norm/ln_parallel_fwd_5120.cu deleted file mode 100644 index 7ada35228cb603ddd26b06e186989746a86926a8..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_5120.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_parallel_fwd_6144.cu b/layer_norm/ln_parallel_fwd_6144.cu deleted file mode 100644 index 6f531c881f7f53651c56e3afd1f0f53c580815ec..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_6144.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/layer_norm/ln_parallel_fwd_7168.cu b/layer_norm/ln_parallel_fwd_7168.cu deleted file mode 100644 index c99e752cd484a99e97f8bf7a92e433a817c54d64..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_7168.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); diff --git a/layer_norm/ln_parallel_fwd_768.cu b/layer_norm/ln_parallel_fwd_768.cu deleted file mode 100644 index f33f519c7fb2934b3b5aabf36a2d9046c4b51ee3..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_768.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/layer_norm/ln_parallel_fwd_8192.cu b/layer_norm/ln_parallel_fwd_8192.cu deleted file mode 100644 index 360e6d4471062cd40bf245ecff22b579f56d4020..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_fwd_8192.cu +++ /dev/null @@ -1,15 +0,0 @@ -#include "ln_parallel_residual_fwd_kernels.cuh" - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); -REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); diff --git a/layer_norm/ln_parallel_residual_bwd_kernels.cuh b/layer_norm/ln_parallel_residual_bwd_kernels.cuh deleted file mode 100644 index 521495724400fde6eaecb27e255154a51d8ddbb0..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_residual_bwd_kernels.cuh +++ /dev/null @@ -1,540 +0,0 @@ -#pragma once - -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "static_switch.h" -#include "ln_bwd_kernels.cuh" - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using input_t = typename Ktraits::input_t; - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using mask_t = typename Ktraits::mask_t; - using Ivec = typename Ktraits::Ivec; - 
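// [Editor's note - illustrative sketch, not part of the deleted sources.] The index
// decomposition a few lines below maps each thread onto a (row, vectorized-column) coordinate:
// CTAs tile rows in groups of ROWS_PER_CTA, warps split into WARPS_M row-warps by WARPS_N
// column-warps, and lanes walk the vectorized columns. A small host-side model of that mapping
// (the helper name and parameters are illustrative):
struct Coord { int row; int vec_col; };

inline Coord map_thread(int block, int tid,
                        int ctas_per_row, int warps_n, int rows_per_cta,
                        int threads_per_warp = 32) {
    const int bidn   = block % ctas_per_row;            // CTA index within a row
    const int bidm   = block / ctas_per_row;            // CTA index over row groups
    const int lane   = tid % threads_per_warp;
    const int warp   = tid / threads_per_warp;
    const int warp_m = warp / warps_n;                  // which row inside the CTA
    const int warp_n = warp % warps_n;                  // which column slice
    const int threads_per_row = warps_n * threads_per_warp;
    return { bidm * rows_per_cta + warp_m,
             bidn * threads_per_row + warp_n * threads_per_warp + lane };
}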
using Rvec = typename Ktraits::Rvec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Mvec = typename Ktraits::Mvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const bool has_residual = params.dresidual != nullptr; - const bool has_x1 = params.dx1 != nullptr; - const bool prenorm = params.dx != nullptr; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dz0y_sum[LDGS]; - Cvec dz0_sum[LDGS]; - Cvec dz1y_sum[LDGS]; - Cvec dz1_sum[LDGS]; - - memset(dz0y_sum, 0, sizeof(dz0y_sum)); - memset(dz0_sum, 0, sizeof(dz0_sum)); - if (!Tied_norm) { - memset(dz1y_sum, 0, sizeof(dz1y_sum)); - memset(dz1_sum, 0, sizeof(dz1_sum)); - } - - compute_t * smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - const index_t num_valid_ldgs = - ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; - - Wvec gamma0[LDGS]; - Wvec gamma1[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - gamma0[it].load_from(params.gamma, idx); - if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); } - idx += Ktraits::VEC_COLS_PER_LDG; - } - } - // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the - // last blocks with syncthreads! - // grid stride over rows - #pragma unroll 1 - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; - const compute_t rs_r = static_cast(params.rs)[row]; - Mvec dmask0[LDGS], dmask1[LDGS]; - Rvec dx[LDGS]; - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; - index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Rvec x; - Ovec dz0, dz1; - dz0.load_from(params.dz, idx); - if (!Tied_norm) { dz1.load_from(params.dz1, idx); } - if (prenorm) { dx[it].load_from(params.dx, idx); } - x.load_from(params.x, idx); - if (Is_dropout) { - dmask0[it].load_from(params.dmask, idx); - if (has_x1) { dmask1[it].load_from(params.dmask1, idx); } - } - idx += Ktraits::VEC_COLS_PER_LDG; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x.data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? 
mu_r : 0.f)); - compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]); - if (!Tied_norm) { - dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]); - } - compute_t dz0_tmp = dz0.data.elt[jt]; - compute_t dz1_tmp; - if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; } - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp; - dz0_sum[it].data.elt[jt] += dz0_tmp; - if (!Tied_norm) { - dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp; - dz1_sum[it].data.elt[jt] += dz1_tmp; - } - } - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; - mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; - - idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ivec dx0, dx1; - Rvec dresidual; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dx_tmp_res; - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); - dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; - if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } - if (Is_dropout) { - dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; - if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; } - } else { - dx0.data.elt[jt] = dx_tmp_res; - if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; } - } - } - if (has_residual) { dresidual.store_to(params.dresidual, idx); } - dx0.store_to(params.dx0, idx); - if (has_x1) { dx1.store_to(params.dx1, idx); } - idx += Ktraits::VEC_COLS_PER_LDG; - } - } - - } // end: grid stride loop - - if( WARPS_M == 1 ) { - idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - dz0_sum[it].store_to(params.dbeta_part, idx); - dz0y_sum[it].store_to(params.dgamma_part, idx); - if (!Tied_norm) { - dz1_sum[it].store_to(params.dbeta1_part, idx); - dz1y_sum[it].store_to(params.dgamma1_part, idx); - } - idx += Ktraits::VEC_COLS_PER_LDG; - } - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. 
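// [Editor's note - illustrative sketch, not part of the deleted sources.] The shared-memory
// staging loops that follow fold the per-warp partial sums into per-CTA partial weight
// gradients. Mathematically, for each column j over the rows this CTA visited:
//   dgamma_j = sum_i dz[i][j] * y[i][j]      dbeta_j = sum_i dz[i][j]
// A serial reference version of that reduction (function name and layout are illustrative):
void reduce_wgrad_partials(const float* partial,   // [rows_per_cta][cols], as staged in smem
                           float* out,             // [cols], per-CTA partial result
                           int rows_per_cta, int cols) {
    for (int j = 0; j < cols; ++j) {
        float acc = 0.f;
        for (int i = 0; i < rows_per_cta; ++i) acc += partial[i * cols + j];
        out[j] = acc;
    }
}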
- enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz0_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz0_sum[NUM_RES]; - memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz0y_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz0y_sum[NUM_RES]; - memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES]; - if (!Tied_norm) { - __syncthreads(); - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz1_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz1y_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - } - - const index_t num_valid_writes - = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; - compute_t *dgamma0_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; - compute_t *dbeta0_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; - compute_t *dgamma1_part = !Tied_norm ? static_cast(params.dgamma1_part) + bidm * params.cols + tidx : nullptr; - compute_t *dbeta1_part = !Tied_norm ? 
static_cast(params.dbeta1_part) + bidm * params.cols + tidx : nullptr; - for( int jt = 0; jt < NUM_RES; jt++ ) { - if (Is_even_cols || (jt < num_valid_writes)) { - *dgamma0_part = cta_dz0y_sum[jt]; - dgamma0_part += Ktraits::THREADS_PER_CTA; - *dbeta0_part = cta_dz0_sum[jt]; - dbeta0_part += Ktraits::THREADS_PER_CTA; - if (!Tied_norm) { - *dgamma1_part = cta_dz1y_sum[jt]; - dgamma1_part += Ktraits::THREADS_PER_CTA; - *dbeta1_part = cta_dz1_sum[jt]; - dbeta1_part += Ktraits::THREADS_PER_CTA; - } - } - } - - } -} - -template -__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) -void ln_parallel_residual_bwd_finalize_kernel(BwdParams params) -{ - - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - // Multiplying by 2 since we have both gamma0 and gamma1 - __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { - // Each thread sums over NUM_ELT columns. 
- Vec dbeta0_local, dgamma0_local, dbeta1_local, dgamma1_local; - memset(&dgamma0_local, 0, sizeof(dgamma0_local)); - memset(&dbeta0_local, 0, sizeof(dbeta0_local)); - memset(&dgamma1_local, 0, sizeof(dgamma1_local)); - memset(&dbeta1_local, 0, sizeof(dbeta1_local)); - if (Is_even_cols || col < params.cols) { - for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { - index_t idx = row * params.cols + col; - - Vec dbeta0_part, dgamma0_part, dbeta1_part, dgamma1_part; - dbeta0_part.load_from(params.dbeta_part, idx); - dgamma0_part.load_from(params.dgamma_part, idx); - dbeta1_part.load_from(params.dbeta1_part, idx); - dgamma1_part.load_from(params.dgamma1_part, idx); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma0_local.data.elt[it] += dgamma0_part.data.elt[it]; - dbeta0_local.data.elt[it] += dbeta0_part.data.elt[it]; - dgamma1_local.data.elt[it] += dgamma1_part.data.elt[it]; - dbeta1_local.data.elt[it] += dbeta1_part.data.elt[it]; - } - } - } - void * smem_gamma0 = smem_; - void * smem_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma0_local.store_to(smem_gamma0, write_idx); - dbeta0_local.store_to(smem_beta0, write_idx); - dgamma1_local.store_to(smem_gamma1, write_idx); - dbeta1_local.store_to(smem_beta1, write_idx); - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta0 and smem_gamma0 - void * smem_gamma0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - void * smem_gamma1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; - void * smem_beta1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT]; - - // More than one iter iff ROWS_PER_CTA < 32. - for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta0_local, 0, sizeof(dbeta0_local)); - memset(&dgamma0_local, 0, sizeof(dgamma0_local)); - memset(&dbeta1_local, 0, sizeof(dbeta1_local)); - memset(&dgamma1_local, 0, sizeof(dgamma1_local)); - - // Load beta and gamma transposed - if(read_row < Kernel_traits::ROWS_PER_CTA){ - dbeta0_local.load_from(smem_beta0, read_idx); - dgamma0_local.load_from(smem_gamma0, read_idx); - dbeta1_local.load_from(smem_beta1, read_idx); - dgamma1_local.load_from(smem_gamma1, read_idx); - } - - // Call reducer on the loaded value(s) and convert. - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - compute_t b0_i = dbeta0_local.data.elt[it]; - compute_t g0_i = dgamma0_local.data.elt[it]; - compute_t b1_i = dbeta1_local.data.elt[it]; - compute_t g1_i = dgamma1_local.data.elt[it]; - b0_i = reducer.allreduce(b0_i, sum); - g0_i = reducer.allreduce(g0_i, sum); - b1_i = reducer.allreduce(b1_i, sum); - g1_i = reducer.allreduce(g1_i, sum); - - dgamma0_local.data.elt[it] = g0_i; - dbeta0_local.data.elt[it] = b0_i; - dgamma1_local.data.elt[it] = g1_i; - dbeta1_local.data.elt[it] = b1_i; - } - - // Leader stores the result at the current column. 
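// [Editor's note - illustrative sketch, not part of the deleted sources.] The
// write_col = lane ^ write_row / read_col = w ^ read_row indexing above is the usual
// XOR-swizzle for a bank-conflict-free shared-memory transpose: XORing the column by the row
// permutes each row's layout so that a transposed read still hits 32 distinct banks, and the
// same swizzled index recovers the original value. A host-side illustration of the round trip:
#include <cassert>

void xor_swizzle_transpose_demo() {
    float smem[32][32];
    // Swizzled write: warp 'row', lane 'lane' writes value row*32 + lane.
    for (int row = 0; row < 32; ++row)
        for (int lane = 0; lane < 32; ++lane)
            smem[row][lane ^ row] = float(row * 32 + lane);
    // Swizzled, transposed read: iteration 'w', lane 'lane' reads what warp 'lane' wrote at
    // position 'w', and the 32 lanes touch 32 distinct banks ((w ^ lane) mod 32 is a permutation).
    for (int w = 0; w < 32; ++w)
        for (int lane = 0; lane < 32; ++lane)
            assert(smem[lane][w ^ lane] == float(lane * 32 + w));
}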
- if(lane == 0){ - dgamma0_local.store_to(smem_gamma0_out, w); - dbeta0_local.store_to(smem_beta0_out, w); - dgamma1_local.store_to(smem_gamma1_out, w); - dbeta1_local.store_to(smem_beta1_out, w); - } - - } - - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. - if (Is_even_cols || col_out * 2 < params.cols) { - if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { - - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta0_vec2, dgamma0_vec2, dbeta1_vec2, dgamma1_vec2; - Vec dbeta0_out2, dgamma0_out2, dbeta1_out2, dgamma1_out2; - - dgamma0_vec2.load_from(smem_gamma0_out, lane); - dbeta0_vec2.load_from(smem_beta0_out, lane); - dgamma1_vec2.load_from(smem_gamma1_out, lane); - dbeta1_vec2.load_from(smem_beta1_out, lane); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma0_out2.data.elt[it] = Converter::convert(dgamma0_vec2.data.elt[it]); - dbeta0_out2.data.elt[it] = Converter::convert(dbeta0_vec2.data.elt[it]); - dgamma1_out2.data.elt[it] = Converter::convert(dgamma1_vec2.data.elt[it]); - dbeta1_out2.data.elt[it] = Converter::convert(dbeta1_vec2.data.elt[it]); - } - dgamma0_out2.store_to(params.dgamma, col_out); - dbeta0_out2.store_to(params.dbeta, col_out); - dgamma1_out2.store_to(params.dgamma1, col_out); - dbeta1_out2.store_to(params.dbeta1, col_out); - } - } - } -} - -} // namespace layer_norm - -using namespace layer_norm; - -template< - typename weight_t, - typename input_t, - typename residual_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - bool is_dropout = launch_params.params.dropout_keep_p < 1.f; - bool tied_norm = launch_params.params.gamma1 == nullptr; - bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { - BOOL_SWITCH(tied_norm, TiedNormConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_parallel_residual_bwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, 
Kernel_traits::SMEM_BYTES, stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = !TiedNormConst - ? &layer_norm::ln_parallel_residual_bwd_finalize_kernel - : &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); - - }); - }); - }); -} diff --git a/layer_norm/ln_parallel_residual_fwd_kernels.cuh b/layer_norm/ln_parallel_residual_fwd_kernels.cuh deleted file mode 100644 index 0e55cb4038b4dbe30d9eb47609df3afea4c4f5fb..0000000000000000000000000000000000000000 --- a/layer_norm/ln_parallel_residual_fwd_kernels.cuh +++ /dev/null @@ -1,281 +0,0 @@ -#pragma once - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include // For at::cuda::philox::unpack -#include - -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "static_switch.h" - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_parallel_residual_fwd_kernel(FwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using input_t = typename Ktraits::input_t; - using residual_t = typename Ktraits::residual_t; - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using mask_t = typename Ktraits::mask_t; - using Ivec = typename Ktraits::Ivec; - using Rvec = typename Ktraits::Rvec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Mvec = typename Ktraits::Mvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - const bool has_residual = params.residual != nullptr; - const bool has_x1 = params.x1 != nullptr; - const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same::value); - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu - curandStatePhilox4_32_10_t state; - if (Is_dropout) { - auto seeds = at::cuda::philox::unpack(params.philox_args); - const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; - curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); - } - - const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; - - Wvec gamma0[LDGS]; - Wvec beta0[LDGS]; - Wvec gamma1[LDGS]; - Wvec beta1[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - 
gamma0[it].load_from(params.gamma, idx); - if (params.beta != nullptr) { - beta0[it].load_from(params.beta, idx); - } else { - beta0[it].zero_(); - } - if (!Tied_norm) { - gamma1[it].load_from(params.gamma1, idx); - if (params.beta1 != nullptr) { - beta1[it].load_from(params.beta1, idx); - } else { - beta1[it].zero_(); - } - } - idx += VEC_COLS_PER_LDG; - } - } - - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ivec x0; - Ivec x1; - Rvec residual; - Rvec x; - Mvec dmask0; - Mvec dmask1; - x0.load_from(params.x0, idx); - if (has_x1) { x1.load_from(params.x1, idx); } - if (has_residual) { residual.load_from(params.residual, idx); } - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use - // the more efficient curand_uniform4. - compute_t x_ij; - mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; - if (Is_dropout) { dmask0.data.elt[jt] = keep0; } - compute_t x0_ij = compute_t(x0.data.elt[jt]); - x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; - if (has_x1) { - mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; - if (Is_dropout) { dmask1.data.elt[jt] = keep1; } - compute_t x1_ij = compute_t(x1.data.elt[jt]); - x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f; - x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij; - } else { - x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; - } - if (save_x) { x.data.elt[jt] = x_ij; } - xf[it * NUM_ELTS + jt] = x_ij; - } - if (save_x) { x.store_to(params.x, idx); } - if (Is_dropout) { - dmask0.store_to(params.dmask, idx); - if (has_x1) { dmask1.store_to(params.dmask1, idx); } - } - idx += VEC_COLS_PER_LDG; - } - } - - static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); - const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; - const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; - const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; - auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { - // Need to convert to int, otherwise the subtraction will wrap around. - const index_t valid_partial_vecs_in_warp = - std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), - int(THREADS_PER_WARP)); - return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; - }; - stats_t s = stats.template compute( - xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS - ); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 
0.f : mu * mu)); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; - } - - idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ovec z0; - Ovec z1; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); - compute_t g0_ij = gamma0[it].data.elt[jt]; - compute_t b0_ij = beta0[it].data.elt[jt]; - z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij); - if (!Tied_norm) { - compute_t g1_ij = gamma1[it].data.elt[jt]; - compute_t b1_ij = beta1[it].data.elt[jt]; - z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij); - } - } - z0.store_to(params.z, idx); - if (!Tied_norm) { z1.store_to(params.z1, idx); } - idx += VEC_COLS_PER_LDG; - } - } - - } -} - -} // namespace layer_norm - -using namespace layer_norm; - -template< - typename weight_t, - typename input_t, - typename residual_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - bool tied_norm = launch_params.params.gamma1 == nullptr; - BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { - BOOL_SWITCH(tied_norm, TiedNormConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_parallel_residual_fwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; - launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } - }); - }); - }); -} diff --git a/layer_norm/ln_utils.cuh b/layer_norm/ln_utils.cuh deleted file mode 100644 index 178d6fda895b478ac76e2a77a2b1b35115fcc279..0000000000000000000000000000000000000000 --- a/layer_norm/ln_utils.cuh +++ /dev/null @@ -1,783 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include "ln.h" - 
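// For reference, the fused forward pass implemented by ln_parallel_residual_fwd_kernel above
// computes, per row: x = dropout(x0) [+ x1] [+ residual], followed by LayerNorm (or RMSNorm)
// with per-column gamma/beta. The following is a minimal host-side scalar sketch of that math
// for a single row with one (tied) weight set; the function name, the std::vector types and
// the always-present residual are simplifying assumptions for illustration only.
#include <cmath>
#include <vector>

inline std::vector<float> reference_dropout_add_ln_row(
        const std::vector<float>& x0, const std::vector<float>& residual,
        const std::vector<float>& gamma, const std::vector<float>& beta,
        const std::vector<bool>& keep,   // dropout mask, true = keep element
        float dropout_scale,             // typically 1 / dropout_keep_p
        float epsilon, bool is_rms_norm) {
    const size_t cols = x0.size();
    std::vector<float> x(cols), z(cols);
    float mu = 0.f;
    for (size_t i = 0; i < cols; ++i) {
        // dropout(x0) + residual (a second input x1 is handled identically to x0 in the kernel)
        x[i] = (keep[i] ? x0[i] * dropout_scale : 0.f) + residual[i];
        mu += x[i];
    }
    mu /= cols;
    float m2 = 0.f;  // sum of squared deviations from mu
    for (size_t i = 0; i < cols; ++i) { const float d = x[i] - mu; m2 += d * d; }
    // The kernel computes rs = rsqrt(m2/cols + eps [+ mu^2 for RMS norm]); adding mu^2 back
    // turns the variance into E[x^2], so RMS norm divides by sqrt(mean(x^2) + eps).
    const float rs = 1.f / std::sqrt(m2 / cols + epsilon + (is_rms_norm ? mu * mu : 0.f));
    for (size_t i = 0; i < cols; ++i) {
        const float y = rs * (x[i] - (is_rms_norm ? 0.f : mu));
        z[i] = gamma[i] * y + beta[i];
    }
    return z;
}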
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -constexpr uint32_t THREADS_PER_WARP = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline void check_cuda_(cudaError_t status, const char *file, int line) { - if( status != cudaSuccess ) { - fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); - exit(status); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(ans) \ - { check_cuda_((ans), __FILE__, __LINE__); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_( \ - launch_params, configure_params); \ - } \ - static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_parallel_residual_( \ - launch_params, configure_params); \ - } \ - static FwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_PARALLEL_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_parallel_residual_(launch_params, configure_params); \ - } \ - static BwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ - ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 operator+(const float2 & 
a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void operator+=(float2 & a, const float2 & b){ - a.x += b.x; - a.y += b.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sum { - inline __device__ Sum(){} - inline __device__ T operator()(const T &a, const T &b){ - return a + b; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ - return __shfl_xor_sync(uint32_t(-1), x, idx); -} - -template<> -inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ - return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; -} - -template -inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ - return __shfl_down_sync(uint32_t(-1), x, idx); -} - -template<> -inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ - return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint8 { - uint4 u; - uint4 v; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BytesToType {}; - -template<> -struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); -}; - -template<> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template<> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeToVec2 {}; - -template<> -struct TypeToVec2 { - using Type = float2; -}; - -template<> -struct TypeToVec2 { - using Type = half2; -}; - -template<> -struct TypeToVec2 { - using Type = nv_bfloat162; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Get { - template - static inline __device__ R of(const T &vec); -}; - -template<> -template -inline __device__ R Get<0>::of(const T &vec) { - return vec.x; -} - -template<> -template -inline __device__ R Get<1>::of(const T &vec) { - return vec.y; -} - -template<> -template -inline __device__ R Get<2>::of(const T &vec) { - return vec.z; -} - -template<> -template -inline __device__ R Get<3>::of(const T &vec) { - return vec.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ 
Dst convert(const Src &from) { - return Dst(from); - } -}; - -template<> -struct Converter{ - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } -}; - -template<> -struct Converter{ - static inline __device__ nv_bfloat162 convert(const float2 &x) { -#if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); -#else - union { - nv_bfloat162 raw; - nv_bfloat16 x; - nv_bfloat16 y; - } tmp; - tmp.x = __float2bfloat16_rn(x.x); - tmp.y = __float2bfloat16_rn(x.y); - return tmp.raw; -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Zeros{ - static inline __device__ T get() { - return T(0.f); - } -}; - -template<> -struct Zeros{ - static inline __device__ float2 get() { - return make_float2(0.f, 0.f); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Vec { - - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; - - Alias_type data; - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } - } - - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } - } - - inline __device__ void zero_() { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = Elt_type(0.f); - } - } - - inline __device__ void load_from(const void *base_ptr, const size_t idx) { - this->data.vec = static_cast(base_ptr)[idx]; - } - - inline __device__ void store_to(void *base_ptr, const size_t idx) { - static_cast(base_ptr)[idx] = this->data.vec; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct InterCTASync { - - template - inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) - : phase_counter_(0) - , b0_(params.barrier + bidm) // The barrier for this group of CTAs. - , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. - { - // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! - } - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); - for( int found = -1; found != expected; ) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); - } - } - - inline __device__ void sync(){ - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 0 : CTAS_PER_ROW; - // There are only 4 phases: up/down for b0/b1. 
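// The two counters are alternated between sync() calls so that a CTA racing ahead into the
// next phase cannot disturb a counter that slower CTAs are still polling; counting up to
// CTAS_PER_ROW on one pass and back down to 0 on the next avoids having to reset the barrier.
// Thread 0 of each CTA adds `step` with release semantics, then spins on acquire loads until
// `expected` is observed, i.e. until every CTA in the row group has arrived.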
- phase_counter_ = (phase_counter_ + 1) & 0x3; - - if( threadIdx.x == 0 ) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); - } - - int phase_counter_; - int * b0_; - int * b1_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using InterCTASync = InterCTASync; - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) - enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params, bidm, bidn) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - { - } - - template - inline __device__ T allreduce(T data, Op &op) { - data = Base::reduce(data, op); - // We switch workspace every iteration. - T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if(this->lane_ < CTAS_PER_ROW){ - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; - } - - InterCTASync inter_cta_; - - T *w0_; - T *w1_; - int bidn_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer { - - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_n_(warp_n) - , lane_(lane) - { - } - - template - static inline __device__ T allreduce_(T data, Op &op) { - #pragma unroll - for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; - } - - template - inline __device__ T allreduce(T data, Op &op) { - return allreduce_(data, op); - } - - template - inline __device__ T reduce(T data, Op &op){ - // only lane 0 holds the result! 
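// Unlike allreduce_ above, which uses XOR (butterfly) shuffles so that every lane ends up
// with the full sum, this reduction uses shuffle-down, so after log2(32) = 5 steps only
// lane 0 holds the complete value; callers that need it in other lanes must broadcast it.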
- #pragma unroll - for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; - } - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using Base = Reducer; - - using Type = T; - - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = &static_cast(smem)[warp_m * WARPS_N]; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ T allreduce(T data, Op & op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - return out; - } - - template - inline __device__ T reduce(T data, Op &op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - } - return out; - } - - T * smem0_; - T * smem1_; - bool use0_; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ - //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) - const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - - #pragma unroll - for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { - // Exchange - int_t n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. - const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - m_a = __shfl_sync(uint32_t(-1), m_a, 0); - m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
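// This specialization merges the per-CTA (mean, m2) pairs that each CTA in the row group
// writes to the workspace, using warp_chan_upd_dynamic above, i.e. the standard pairwise
// (Chan-style) update: n_ab = n_a + n_b, mean_ab = (n_a*mean_a + n_b*mean_b) / n_ab,
// M2_ab = M2_a + M2_b + (mean_a - mean_b)^2 * n_a * n_b / n_ab.
// Worked example: partitions {1,2} (n=2, mean=1.5, M2=0.5) and {3,4,5} (n=3, mean=4, M2=2)
// combine to n=5, mean=3, M2 = 0.5 + 2 + 2.5^2 * 2*3/5 = 10, matching the direct computation
// over {1,2,3,4,5}.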
- - using InterCTASync = InterCTASync; - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : inter_cta_(params, bidm, bidn) - , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - , warp_n_(warp_n) - , lane_(lane) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - // TODO rn is not really needed here.. - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn); - - stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if( warp_n_ == 0 && lane_ == 0 ) { - workspace[bidn_] = block_stats; - } - - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume CTA group size in N less than 32, such that we can finalize with a single warp. - static_assert(CTAS_PER_ROW <= 32); - - // Every warp does the final reduction locally. - if( lane_ < CTAS_PER_ROW ) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); - - return { m, m2 }; - } - - InterCTASync inter_cta_; - BlockStats block_stats_; - - stats_t *w0_; - stats_t *w1_; - int bidn_; - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, - function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { - stats_t * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - const auto warp_n = warp_stats_.reducer_.warp_n_; - const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); - stats_t warp_stats = warp_stats_.template compute( - elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts - ); - - //Each warp warp leader stores its stats - const auto lane = warp_stats_.reducer_.lane_; - if( lane == 0 ) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - int n = 0;; - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with a single warp - static_assert(WARPS_N <= 32); - if(lane < WARPS_N){ - stats_t result = smem[lane]; - n = Is_even_cols ? 
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); - - return { m, m2 }; - } - WarpStats warp_stats_; - stats_t * smem0_; - stats_t * smem1_; - bool use0_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, - // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { - function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { - - auto sum = Sum(); - - T m = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - if (Is_even_cols || (it < num_valid_elts)) { - m += elts[it]; - } - } - m = reducer_.allreduce(m, sum) * row_norm_factor; - - T m2 = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - if (Is_even_cols || (it < num_valid_elts)) { - T diff = (elts[it] - m); - m2 += diff * diff; - } - } - m2 = reducer_.allreduce(m2, sum); - - return {m, m2}; - } - - Reducer reducer_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/layer_norm/static_switch.h b/layer_norm/static_switch.h deleted file mode 100644 index 7920ac045d0a2a1f4c4159ee3eebe51fe1e2c203..0000000000000000000000000000000000000000 --- a/layer_norm/static_switch.h +++ /dev/null @@ -1,25 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index b9da95d03362688455d5ba0560560268b3ad6832..0000000000000000000000000000000000000000 --- a/pyproject.toml +++ /dev/null @@ -1,10 +0,0 @@ -[build-system] -requires = [ - "cmake>=3.26", - "ninja", - "packaging", - "setuptools>=61", - "torch", - "wheel", -] -build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 5c563b2266c9c2d011021d9eb9f1bd7254340f8e..0000000000000000000000000000000000000000 --- a/setup.py +++ /dev/null @@ -1,138 +0,0 @@ -import logging -import os -from shutil import which, move -import subprocess -import sys -from pathlib import Path - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - -logger = logging.getLogger(__name__) - - -def is_sccache_available() -> bool: - return which("sccache") is not None - - -def is_ccache_available() -> bool: - return which("ccache") is not None - - -def is_ninja_available() -> bool: - return which("ninja") is not None - - -class CMakeExtension(Extension): - def __init__(self, name: str, sourcedir: str = "") -> None: - super().__init__(name, sources=[], py_limited_api=True) - self.sourcedir = os.fspath(Path(sourcedir).resolve()) - - -class CMakeBuild(build_ext): - def build_extension(self, ext: CMakeExtension) -> None: - ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) - extdir = ext_fullpath.parent.resolve() - - debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug - cfg = "Debug" if debug else "Release" - - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", - f"-DPython_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - build_args = [] - if "CMAKE_ARGS" in os.environ: - cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] - - if not cmake_generator or cmake_generator == "Ninja": - try: - import ninja - - ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" - cmake_args += [ - "-GNinja", - f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", - ] - except ImportError: - pass - - if is_sccache_available(): - cmake_args += [ - "-DCMAKE_C_COMPILER_LAUNCHER=sccache", - "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", - "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", - "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", - ] - elif is_ccache_available(): - cmake_args += [ - "-DCMAKE_C_COMPILER_LAUNCHER=ccache", - "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", - "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", - "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", - ] - - num_jobs = os.getenv("MAX_JOBS", None) - if num_jobs is not None: - num_jobs = int(num_jobs) - logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) - else: - try: - # os.sched_getaffinity() isn't universally available, so fall - # back to os.cpu_count() if we get an error here. 
- num_jobs = len(os.sched_getaffinity(0)) - except AttributeError: - num_jobs = os.cpu_count() - - nvcc_threads = os.getenv("NVCC_THREADS", None) - if nvcc_threads is not None: - nvcc_threads = int(nvcc_threads) - logger.info( - "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads - ) - else: - nvcc_threads = 1 - num_jobs = max(1, num_jobs // nvcc_threads) - - build_args += [f"-j{num_jobs}"] - if sys.platform == "win32": - build_args += ["--config", cfg] - - if nvcc_threads: - cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True - ) - subprocess.run( - ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True - ) - if sys.platform == "win32": - # Move the dylib one folder up for discovery. - for filename in os.listdir(extdir / cfg): - move(extdir / cfg / filename, extdir / filename) - - - -setup( - name="layer_norm", - # The version is just a stub, it's not used by the final build artefact. - version="0.1.0", - ext_modules=[CMakeExtension("layer_norm._layer_norm_711aa42_dirty")], - cmdclass={"build_ext": CMakeBuild}, - packages=find_packages(where="torch-ext", include=["layer_norm*"]), - package_dir={"": "torch-ext"}, - zip_safe=False, - install_requires=["torch"], - python_requires=">=3.9", -) \ No newline at end of file diff --git a/setup_backup.py b/setup_backup.py deleted file mode 100644 index b9337dab15de68aa8dc02c1b961cb0fe8e6e4feb..0000000000000000000000000000000000000000 --- a/setup_backup.py +++ /dev/null @@ -1,203 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? 
" - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--fast_layer_norm") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="dropout_layer_norm", - sources=[ - "ln_api.cpp", - "ln_fwd_256.cu", - "ln_bwd_256.cu", - "ln_fwd_512.cu", - "ln_bwd_512.cu", - "ln_fwd_768.cu", - "ln_bwd_768.cu", - "ln_fwd_1024.cu", - "ln_bwd_1024.cu", - "ln_fwd_1280.cu", - "ln_bwd_1280.cu", - "ln_fwd_1536.cu", - "ln_bwd_1536.cu", - "ln_fwd_2048.cu", - "ln_bwd_2048.cu", - "ln_fwd_2560.cu", - "ln_bwd_2560.cu", - "ln_fwd_3072.cu", - "ln_bwd_3072.cu", - "ln_fwd_4096.cu", - "ln_bwd_4096.cu", - "ln_fwd_5120.cu", - "ln_bwd_5120.cu", - "ln_fwd_6144.cu", - 
"ln_bwd_6144.cu", - "ln_fwd_7168.cu", - "ln_bwd_7168.cu", - "ln_fwd_8192.cu", - "ln_bwd_8192.cu", - "ln_parallel_fwd_256.cu", - "ln_parallel_bwd_256.cu", - "ln_parallel_fwd_512.cu", - "ln_parallel_bwd_512.cu", - "ln_parallel_fwd_768.cu", - "ln_parallel_bwd_768.cu", - "ln_parallel_fwd_1024.cu", - "ln_parallel_bwd_1024.cu", - "ln_parallel_fwd_1280.cu", - "ln_parallel_bwd_1280.cu", - "ln_parallel_fwd_1536.cu", - "ln_parallel_bwd_1536.cu", - "ln_parallel_fwd_2048.cu", - "ln_parallel_bwd_2048.cu", - "ln_parallel_fwd_2560.cu", - "ln_parallel_bwd_2560.cu", - "ln_parallel_fwd_3072.cu", - "ln_parallel_bwd_3072.cu", - "ln_parallel_fwd_4096.cu", - "ln_parallel_bwd_4096.cu", - "ln_parallel_fwd_5120.cu", - "ln_parallel_bwd_5120.cu", - "ln_parallel_fwd_6144.cu", - "ln_parallel_bwd_6144.cu", - "ln_parallel_fwd_7168.cu", - "ln_parallel_bwd_7168.cu", - "ln_parallel_fwd_8192.cu", - "ln_parallel_bwd_8192.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="dropout_layer_norm", - version="0.1", - description="Fused dropout + add + layer norm", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp deleted file mode 100644 index 308ce081ec5f8a9150c23a110fe19a85b283955c..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.cpp +++ /dev/null @@ -1,154 +0,0 @@ -#include - -#include "registration.h" -#include "torch_binding.h" - -// Helper to turn Tensor? 
from schema (optional by value) into optional& args -template -static c10::optional as_const_opt(const c10::optional& v) { - if (v.has_value()) return c10::optional(v.value()); - return c10::optional(); -} - -// Wrappers with dispatcher-friendly types (double scalars, optional Generator) -// Forward -static std::vector dropout_add_ln_fwd_wrap( - const at::Tensor& input, - const at::Tensor& gamma, - c10::optional beta, - c10::optional rowscale, - c10::optional colscale, - c10::optional x0_subset, - c10::optional z_subset, - double dropout_p, - double epsilon, - double rowscale_const, - int64_t z_numrows, - c10::optional gen, - bool residual_in_fp32, - bool is_rms_norm) { - - // residual is not exposed in this schema (None) - auto residual_c = c10::optional(); - auto beta_c = as_const_opt(beta); - auto rowscale_c = as_const_opt(rowscale); - auto colscale_c = as_const_opt(colscale); - auto x0_subset_c = as_const_opt(x0_subset); - auto z_subset_c = as_const_opt(z_subset); - - return dropout_add_ln_fwd( - input, residual_c, gamma, beta_c, rowscale_c, colscale_c, x0_subset_c, z_subset_c, - static_cast(dropout_p), - static_cast(epsilon), - static_cast(rowscale_const), - z_numrows, gen, residual_in_fp32, is_rms_norm); -} - -// Backward -static std::vector dropout_add_ln_bwd_wrap( - const at::Tensor& dz, - c10::optional dx, - const at::Tensor& x, - c10::optional x0, - c10::optional dmask, - const at::Tensor& mu, - const at::Tensor& rsigma, - const at::Tensor& gamma, - c10::optional rowscale, - c10::optional colscale, - c10::optional x0_subset, - c10::optional z_subset, - double dropout_p, - double rowscale_const, - int64_t x0_numrows, - bool has_residual, - bool is_rms_norm) { - - auto dx_c = as_const_opt(dx); - auto x0_c = as_const_opt(x0); - auto dmask_c = as_const_opt(dmask); - auto rowscale_c = as_const_opt(rowscale); - auto colscale_c = as_const_opt(colscale); - auto x0_subset_c = as_const_opt(x0_subset); - auto z_subset_c = as_const_opt(z_subset); - - return dropout_add_ln_bwd( - dz, dx_c, x, x0_c, dmask_c, mu, rsigma, gamma, - rowscale_c, colscale_c, x0_subset_c, z_subset_c, - static_cast(dropout_p), - static_cast(rowscale_const), - x0_numrows, has_residual, is_rms_norm); -} - -// Parallel forward -static std::vector dropout_add_ln_parallel_residual_fwd_wrap( - const at::Tensor& input, - c10::optional x1, - c10::optional residual, - const at::Tensor& gamma0, - c10::optional beta0, - c10::optional gamma1, - c10::optional beta1, - double dropout_p, - double epsilon, - c10::optional gen, - bool residual_in_fp32, - bool is_rms_norm) { - - auto x1_c = as_const_opt(x1); - auto residual_c = as_const_opt(residual); - auto beta0_c = as_const_opt(beta0); - auto gamma1_c = as_const_opt(gamma1); - auto beta1_c = as_const_opt(beta1); - - return dropout_add_ln_parallel_residual_fwd( - input, x1_c, residual_c, gamma0, beta0_c, gamma1_c, beta1_c, - static_cast(dropout_p), - static_cast(epsilon), - gen, residual_in_fp32, is_rms_norm); -} - -// Parallel backward -static std::vector dropout_add_ln_parallel_residual_bwd_wrap( - const at::Tensor& dz0, - c10::optional dz1, - c10::optional dx, - const at::Tensor& x, - c10::optional dmask0, - c10::optional dmask1, - const at::Tensor& mu, - const at::Tensor& rsigma, - const at::Tensor& gamma0, - c10::optional gamma1, - double dropout_p, - bool has_x1, - bool has_residual, - bool is_rms_norm) { - - auto dz1_c = as_const_opt(dz1); - auto dx_c = as_const_opt(dx); - auto dmask0_c = as_const_opt(dmask0); - auto dmask1_c = as_const_opt(dmask1); - auto gamma1_c = 
as_const_opt(gamma1); - - return dropout_add_ln_parallel_residual_bwd( - dz0, dz1_c, dx_c, x, dmask0_c, dmask1_c, mu, rsigma, gamma0, gamma1_c, - static_cast(dropout_p), has_x1, has_residual, is_rms_norm); -} - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - // Return lists to match std::vector from implementations - ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor? beta, Tensor? rowscale, Tensor? colscale, Tensor? x0_subset, Tensor? z_subset, float dropout_p, float epsilon, float rowscale_const, int z_numrows, Generator? gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor[]"); - ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd_wrap); - - ops.def("dropout_add_ln_bwd(Tensor dz, Tensor? dx, Tensor x, Tensor? x0, Tensor? dmask, Tensor mu, Tensor rsigma, Tensor gamma, Tensor? rowscale, Tensor? colscale, Tensor? x0_subset, Tensor? z_subset, float dropout_p, float rowscale_const, int x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor[]"); - ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd_wrap); - - ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor? x1, Tensor? residual, Tensor gamma0, Tensor? beta0, Tensor? gamma1, Tensor? beta1, float dropout_p, float epsilon, Generator? gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor[]"); - ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd_wrap); - - ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor? dz1, Tensor? dx, Tensor x, Tensor? dmask0, Tensor? dmask1, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor? gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor[]"); - ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd_wrap); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h deleted file mode 100644 index 4090367438bcec839585d5380a32df1b4c84f34d..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.h +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once - -#include - -// Declarations for implementations defined in layer_norm/ln_api.cpp -std::vector dropout_add_ln_fwd( - const at::Tensor &x0, - c10::optional &residual, - const at::Tensor &gamma, - c10::optional &beta, - c10::optional &rowscale, - c10::optional &colscale, - c10::optional &x0_subset, - c10::optional &z_subset, - const float dropout_p, - const float epsilon, - const float rowscale_const, - const int64_t z_numrows, - c10::optional gen, - bool residual_in_fp32, - bool is_rms_norm); - -std::vector dropout_add_ln_bwd( - const at::Tensor &dz, - c10::optional &dx, - const at::Tensor &x, - c10::optional &x0, - c10::optional &dmask, - const at::Tensor &mu, - const at::Tensor &rsigma, - const at::Tensor &gamma, - c10::optional &rowscale, - c10::optional &colscale, - c10::optional &x0_subset, - c10::optional &z_subset, - const float dropout_p, - const float rowscale_const, - const int64_t x0_numrows, - const bool has_residual, - bool is_rms_norm); - -std::vector dropout_add_ln_parallel_residual_fwd( - const at::Tensor &x0, - c10::optional &x1, - c10::optional &residual, - const at::Tensor &gamma0, - c10::optional &beta0, - c10::optional &gamma1, - c10::optional &beta1, - const float dropout_p, - const float epsilon, - c10::optional gen, - bool residual_in_fp32, - bool is_rms_norm); - -std::vector dropout_add_ln_parallel_residual_bwd( - const at::Tensor &dz0, - c10::optional &dz1, 
- c10::optional &dx, - const at::Tensor &x, - c10::optional &dmask0, - c10::optional &dmask1, - const at::Tensor &mu, - const at::Tensor &rsigma, - const at::Tensor &gamma0, - c10::optional &gamma1, - const float dropout_p, - const bool has_x1, - const bool has_residual, - bool is_rms_norm); \ No newline at end of file
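// Note on these declarations: the schemas registered in torch_binding.cpp use `float`/`int`,
// which the PyTorch dispatcher passes as double/int64_t; the *_wrap functions there accept
// those wider types (and optional tensors by value) and narrow/convert them before calling
// the implementations declared here, which take float scalars and optional-tensor reference
// arguments.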