Torch 2.9 builds
Browse filesThis view is limited to 50 files because it contains too many changes. Â
See raw diff
- .gitignore +1 -0
- CMakeLists.txt +0 -213
- README.md +1 -20
- api.py +0 -800
- build/{torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
- build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py +3 -3
- build/{torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
- build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/layer_norm/{_layer_norm_f622ea1_dirty.abi3.so → _layer_norm_f3fd6bf.abi3.so} +2 -2
- build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
- build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
- build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +0 -3
- build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
- build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +0 -3
- build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py +3 -3
- {torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/__init__.py +0 -0
- build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
- build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +9 -0
- {torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/layers.py +0 -0
- build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py +26 -0
- build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
- build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +9 -0
- build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py +51 -0
- build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py +26 -0
- build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
- build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py +9 -0
- build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py +51 -0
- cmake/hipify.py +0 -76
- cmake/utils.cmake +0 -545
- layer_norm/ln.h +0 -281
- layer_norm/ln_api.cpp +0 -828
- layer_norm/ln_bwd_1024.cu +0 -15
- layer_norm/ln_bwd_1280.cu +0 -15
- layer_norm/ln_bwd_1536.cu +0 -15
- layer_norm/ln_bwd_2048.cu +0 -15
- layer_norm/ln_bwd_256.cu +0 -15
- layer_norm/ln_bwd_2560.cu +0 -15
- layer_norm/ln_bwd_3072.cu +0 -15
- layer_norm/ln_bwd_4096.cu +0 -15
- layer_norm/ln_bwd_512.cu +0 -15
- layer_norm/ln_bwd_5120.cu +0 -15
- layer_norm/ln_bwd_6144.cu +0 -15
- layer_norm/ln_bwd_7168.cu +0 -15
- layer_norm/ln_bwd_768.cu +0 -15
- layer_norm/ln_bwd_8192.cu +0 -15
- layer_norm/ln_bwd_kernels.cuh +0 -534
- layer_norm/ln_fwd_1024.cu +0 -15
.gitignore
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
__pycache__/
|
|
|
|
| 2 |
*.pyc
|
|
|
|
| 1 |
__pycache__/
|
| 2 |
+
**/__pycache__/
|
| 3 |
*.pyc
|
CMakeLists.txt
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
cmake_minimum_required(VERSION 3.26)
|
| 2 |
-
project(layer_norm LANGUAGES CXX)
|
| 3 |
-
|
| 4 |
-
set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")
|
| 5 |
-
|
| 6 |
-
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
| 7 |
-
|
| 8 |
-
include(FetchContent)
|
| 9 |
-
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
|
| 10 |
-
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
| 11 |
-
|
| 12 |
-
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
| 13 |
-
|
| 14 |
-
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
|
| 15 |
-
|
| 16 |
-
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
| 17 |
-
|
| 18 |
-
if(DEFINED Python_EXECUTABLE)
|
| 19 |
-
# Allow passing through the interpreter (e.g. from setup.py).
|
| 20 |
-
find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
|
| 21 |
-
if (NOT Python_FOUND)
|
| 22 |
-
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
| 23 |
-
endif()
|
| 24 |
-
else()
|
| 25 |
-
find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
|
| 26 |
-
endif()
|
| 27 |
-
|
| 28 |
-
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
|
| 29 |
-
|
| 30 |
-
find_package(Torch REQUIRED)
|
| 31 |
-
|
| 32 |
-
if (NOT TARGET_DEVICE STREQUAL "cuda" AND
|
| 33 |
-
NOT TARGET_DEVICE STREQUAL "rocm")
|
| 34 |
-
return()
|
| 35 |
-
endif()
|
| 36 |
-
|
| 37 |
-
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
|
| 38 |
-
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
| 39 |
-
set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
|
| 40 |
-
else()
|
| 41 |
-
set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
|
| 42 |
-
endif()
|
| 43 |
-
|
| 44 |
-
if (NOT HIP_FOUND AND CUDA_FOUND)
|
| 45 |
-
set(GPU_LANG "CUDA")
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
elseif(HIP_FOUND)
|
| 50 |
-
set(GPU_LANG "HIP")
|
| 51 |
-
|
| 52 |
-
# Importing torch recognizes and sets up some HIP/ROCm configuration but does
|
| 53 |
-
# not let cmake recognize .hip files. In order to get cmake to understand the
|
| 54 |
-
# .hip extension automatically, HIP must be enabled explicitly.
|
| 55 |
-
enable_language(HIP)
|
| 56 |
-
else()
|
| 57 |
-
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
| 58 |
-
endif()
|
| 59 |
-
|
| 60 |
-
if(GPU_LANG STREQUAL "CUDA")
|
| 61 |
-
clear_cuda_arches(CUDA_ARCH_FLAGS)
|
| 62 |
-
extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
|
| 63 |
-
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
|
| 64 |
-
# Filter the target architectures by the supported supported archs
|
| 65 |
-
# since for some files we will build for all CUDA_ARCHS.
|
| 66 |
-
cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
|
| 67 |
-
message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
|
| 68 |
-
|
| 69 |
-
if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
|
| 70 |
-
list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
|
| 71 |
-
endif()
|
| 72 |
-
|
| 73 |
-
add_compile_definitions(CUDA_KERNEL)
|
| 74 |
-
elseif(GPU_LANG STREQUAL "HIP")
|
| 75 |
-
set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
|
| 76 |
-
# TODO: remove this once we can set specific archs per source file set.
|
| 77 |
-
override_gpu_arches(GPU_ARCHES
|
| 78 |
-
${GPU_LANG}
|
| 79 |
-
"${${GPU_LANG}_SUPPORTED_ARCHS}")
|
| 80 |
-
|
| 81 |
-
add_compile_definitions(ROCM_KERNEL)
|
| 82 |
-
else()
|
| 83 |
-
override_gpu_arches(GPU_ARCHES
|
| 84 |
-
${GPU_LANG}
|
| 85 |
-
"${${GPU_LANG}_SUPPORTED_ARCHS}")
|
| 86 |
-
endif()
|
| 87 |
-
|
| 88 |
-
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
|
| 89 |
-
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
|
| 90 |
-
|
| 91 |
-
set(TORCH_layer_norm_SRC
|
| 92 |
-
torch-ext/torch_binding.cpp torch-ext/torch_binding.h
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
list(APPEND SRC "${TORCH_layer_norm_SRC}")
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
set(layer_norm_SRC
|
| 100 |
-
"layer_norm/ln.h"
|
| 101 |
-
"layer_norm/ln_api.cpp"
|
| 102 |
-
"layer_norm/ln_bwd_1024.cu"
|
| 103 |
-
"layer_norm/ln_bwd_1280.cu"
|
| 104 |
-
"layer_norm/ln_bwd_1536.cu"
|
| 105 |
-
"layer_norm/ln_bwd_2048.cu"
|
| 106 |
-
"layer_norm/ln_bwd_256.cu"
|
| 107 |
-
"layer_norm/ln_bwd_2560.cu"
|
| 108 |
-
"layer_norm/ln_bwd_3072.cu"
|
| 109 |
-
"layer_norm/ln_bwd_4096.cu"
|
| 110 |
-
"layer_norm/ln_bwd_512.cu"
|
| 111 |
-
"layer_norm/ln_bwd_5120.cu"
|
| 112 |
-
"layer_norm/ln_bwd_6144.cu"
|
| 113 |
-
"layer_norm/ln_bwd_7168.cu"
|
| 114 |
-
"layer_norm/ln_bwd_768.cu"
|
| 115 |
-
"layer_norm/ln_bwd_8192.cu"
|
| 116 |
-
"layer_norm/ln_bwd_kernels.cuh"
|
| 117 |
-
"layer_norm/ln_fwd_1024.cu"
|
| 118 |
-
"layer_norm/ln_fwd_1280.cu"
|
| 119 |
-
"layer_norm/ln_fwd_1536.cu"
|
| 120 |
-
"layer_norm/ln_fwd_2048.cu"
|
| 121 |
-
"layer_norm/ln_fwd_256.cu"
|
| 122 |
-
"layer_norm/ln_fwd_2560.cu"
|
| 123 |
-
"layer_norm/ln_fwd_3072.cu"
|
| 124 |
-
"layer_norm/ln_fwd_4096.cu"
|
| 125 |
-
"layer_norm/ln_fwd_512.cu"
|
| 126 |
-
"layer_norm/ln_fwd_5120.cu"
|
| 127 |
-
"layer_norm/ln_fwd_6144.cu"
|
| 128 |
-
"layer_norm/ln_fwd_7168.cu"
|
| 129 |
-
"layer_norm/ln_fwd_768.cu"
|
| 130 |
-
"layer_norm/ln_fwd_8192.cu"
|
| 131 |
-
"layer_norm/ln_fwd_kernels.cuh"
|
| 132 |
-
"layer_norm/ln_kernel_traits.h"
|
| 133 |
-
"layer_norm/ln_parallel_bwd_1024.cu"
|
| 134 |
-
"layer_norm/ln_parallel_bwd_1280.cu"
|
| 135 |
-
"layer_norm/ln_parallel_bwd_1536.cu"
|
| 136 |
-
"layer_norm/ln_parallel_bwd_2048.cu"
|
| 137 |
-
"layer_norm/ln_parallel_bwd_256.cu"
|
| 138 |
-
"layer_norm/ln_parallel_bwd_2560.cu"
|
| 139 |
-
"layer_norm/ln_parallel_bwd_3072.cu"
|
| 140 |
-
"layer_norm/ln_parallel_bwd_4096.cu"
|
| 141 |
-
"layer_norm/ln_parallel_bwd_512.cu"
|
| 142 |
-
"layer_norm/ln_parallel_bwd_5120.cu"
|
| 143 |
-
"layer_norm/ln_parallel_bwd_6144.cu"
|
| 144 |
-
"layer_norm/ln_parallel_bwd_7168.cu"
|
| 145 |
-
"layer_norm/ln_parallel_bwd_768.cu"
|
| 146 |
-
"layer_norm/ln_parallel_bwd_8192.cu"
|
| 147 |
-
"layer_norm/ln_parallel_fwd_1024.cu"
|
| 148 |
-
"layer_norm/ln_parallel_fwd_1280.cu"
|
| 149 |
-
"layer_norm/ln_parallel_fwd_1536.cu"
|
| 150 |
-
"layer_norm/ln_parallel_fwd_2048.cu"
|
| 151 |
-
"layer_norm/ln_parallel_fwd_256.cu"
|
| 152 |
-
"layer_norm/ln_parallel_fwd_2560.cu"
|
| 153 |
-
"layer_norm/ln_parallel_fwd_3072.cu"
|
| 154 |
-
"layer_norm/ln_parallel_fwd_4096.cu"
|
| 155 |
-
"layer_norm/ln_parallel_fwd_512.cu"
|
| 156 |
-
"layer_norm/ln_parallel_fwd_5120.cu"
|
| 157 |
-
"layer_norm/ln_parallel_fwd_6144.cu"
|
| 158 |
-
"layer_norm/ln_parallel_fwd_7168.cu"
|
| 159 |
-
"layer_norm/ln_parallel_fwd_768.cu"
|
| 160 |
-
"layer_norm/ln_parallel_fwd_8192.cu"
|
| 161 |
-
"layer_norm/ln_parallel_residual_bwd_kernels.cuh"
|
| 162 |
-
"layer_norm/ln_parallel_residual_fwd_kernels.cuh"
|
| 163 |
-
"layer_norm/ln_utils.cuh"
|
| 164 |
-
"layer_norm/static_switch.h"
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
# TODO: check if CLion support this:
|
| 168 |
-
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
|
| 169 |
-
set_source_files_properties(
|
| 170 |
-
${layer_norm_SRC}
|
| 171 |
-
PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
|
| 172 |
-
|
| 173 |
-
if(GPU_LANG STREQUAL "CUDA")
|
| 174 |
-
cuda_archs_loose_intersection(layer_norm_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}")
|
| 175 |
-
message(STATUS "Capabilities for kernel layer_norm: ${layer_norm_ARCHS}")
|
| 176 |
-
set_gencode_flags_for_srcs(SRCS "${layer_norm_SRC}" CUDA_ARCHS "${layer_norm_ARCHS}")
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
foreach(_KERNEL_SRC ${layer_norm_SRC})
|
| 180 |
-
if(_KERNEL_SRC MATCHES ".*\\.cu$")
|
| 181 |
-
set_property(
|
| 182 |
-
SOURCE ${_KERNEL_SRC}
|
| 183 |
-
APPEND PROPERTY
|
| 184 |
-
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;-U__CUDA_NO_BFLOAT16_OPERATORS__;-U__CUDA_NO_BFLOAT16_CONVERSIONS__;-U__CUDA_NO_BFLOAT162_OPERATORS__;-U__CUDA_NO_BFLOAT162_CONVERSIONS__;--expt-relaxed-constexpr;--expt-extended-lambda;--use_fast_math>"
|
| 185 |
-
)
|
| 186 |
-
endif()
|
| 187 |
-
endforeach()
|
| 188 |
-
|
| 189 |
-
foreach(_KERNEL_SRC ${layer_norm_SRC})
|
| 190 |
-
set_property(
|
| 191 |
-
SOURCE ${_KERNEL_SRC}
|
| 192 |
-
APPEND PROPERTY
|
| 193 |
-
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-DFLASHATTENTION_DISABLE_PYBIND>"
|
| 194 |
-
)
|
| 195 |
-
endforeach()
|
| 196 |
-
|
| 197 |
-
list(APPEND SRC "${layer_norm_SRC}")
|
| 198 |
-
endif()
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
define_gpu_extension_target(
|
| 202 |
-
_layer_norm_711aa42_dirty
|
| 203 |
-
DESTINATION _layer_norm_711aa42_dirty
|
| 204 |
-
LANGUAGE ${GPU_LANG}
|
| 205 |
-
SOURCES ${SRC}
|
| 206 |
-
COMPILE_FLAGS ${GPU_FLAGS}
|
| 207 |
-
ARCHITECTURES ${GPU_ARCHES}
|
| 208 |
-
#INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
| 209 |
-
USE_SABI 3
|
| 210 |
-
WITH_SOABI)
|
| 211 |
-
|
| 212 |
-
target_link_options(_layer_norm_711aa42_dirty PRIVATE -static-libstdc++)
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -2,23 +2,4 @@
|
|
| 2 |
tags:
|
| 3 |
- kernel
|
| 4 |
---
|
| 5 |
-
This CUDA extension implements fused dropout + residual + LayerNorm
|
| 6 |
-
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
|
| 7 |
-
Major changes:
|
| 8 |
-
- Add dropout and residual.
|
| 9 |
-
- Make it work for both pre-norm and post-norm architecture.
|
| 10 |
-
- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
|
| 11 |
-
- Implement RMSNorm as an option.
|
| 12 |
-
- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
|
| 13 |
-
|
| 14 |
-
If you want to use it for dimensions larger than 8k, please file an issue.
|
| 15 |
-
|
| 16 |
-
This extension has only been tested on A100s.
|
| 17 |
-
|
| 18 |
-
```sh
|
| 19 |
-
cd csrc/layer_norm && pip install .
|
| 20 |
-
```
|
| 21 |
-
|
| 22 |
-
As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
|
| 23 |
-
We've instead switched to a Triton-based
|
| 24 |
-
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
|
|
|
|
| 2 |
tags:
|
| 3 |
- kernel
|
| 4 |
---
|
| 5 |
+
This CUDA extension implements fused dropout + residual + LayerNorm from the [flash-attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm) repo.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api.py
DELETED
|
@@ -1,800 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2022, Tri Dao.
|
| 2 |
-
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
| 3 |
-
|
| 4 |
-
import dropout_layer_norm
|
| 5 |
-
import torch
|
| 6 |
-
from torch.nn import init
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def maybe_align(x, alignment_in_bytes=16):
|
| 10 |
-
"""Assume that x already has last dim divisible by alignment_in_bytes"""
|
| 11 |
-
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
|
| 12 |
-
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
|
| 13 |
-
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def _dropout_add_layer_norm_forward(
|
| 17 |
-
x0,
|
| 18 |
-
residual,
|
| 19 |
-
gamma,
|
| 20 |
-
beta,
|
| 21 |
-
rowscale,
|
| 22 |
-
colscale,
|
| 23 |
-
dropout_p,
|
| 24 |
-
epsilon,
|
| 25 |
-
residual_in_fp32=False,
|
| 26 |
-
is_rms_norm=False,
|
| 27 |
-
):
|
| 28 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
| 29 |
-
hidden_size = gamma.numel()
|
| 30 |
-
x0mat = x0.view((-1, hidden_size))
|
| 31 |
-
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
| 32 |
-
rowscale = rowscale.view(-1) if rowscale is not None else None
|
| 33 |
-
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
| 34 |
-
x0mat,
|
| 35 |
-
residualmat,
|
| 36 |
-
gamma,
|
| 37 |
-
beta,
|
| 38 |
-
rowscale,
|
| 39 |
-
colscale,
|
| 40 |
-
None,
|
| 41 |
-
None,
|
| 42 |
-
dropout_p,
|
| 43 |
-
epsilon,
|
| 44 |
-
1.0,
|
| 45 |
-
0,
|
| 46 |
-
None,
|
| 47 |
-
residual_in_fp32,
|
| 48 |
-
is_rms_norm,
|
| 49 |
-
)
|
| 50 |
-
# dmask is None if dropout_p == 0.0
|
| 51 |
-
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
| 52 |
-
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _dropout_add_layer_norm_backward(
|
| 56 |
-
dz,
|
| 57 |
-
dx,
|
| 58 |
-
x,
|
| 59 |
-
x0,
|
| 60 |
-
dmask,
|
| 61 |
-
mu,
|
| 62 |
-
rsigma,
|
| 63 |
-
gamma,
|
| 64 |
-
rowscale,
|
| 65 |
-
colscale,
|
| 66 |
-
dropout_p,
|
| 67 |
-
has_residual,
|
| 68 |
-
is_rms_norm=False,
|
| 69 |
-
):
|
| 70 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes
|
| 71 |
-
dx == None means that it was a post-norm architecture
|
| 72 |
-
(x = drop(x0) + residual was not returned in the fwd).
|
| 73 |
-
x0 must not be None if we have colscale.
|
| 74 |
-
"""
|
| 75 |
-
hidden_size = gamma.numel()
|
| 76 |
-
xmat = x.view((-1, hidden_size))
|
| 77 |
-
dzmat = dz.view(xmat.shape)
|
| 78 |
-
dxmat = dx.view(xmat.shape) if dx is not None else None
|
| 79 |
-
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
| 80 |
-
rowscale = rowscale.view(-1) if rowscale is not None else None
|
| 81 |
-
if colscale is not None:
|
| 82 |
-
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
| 83 |
-
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
| 84 |
-
dzmat,
|
| 85 |
-
dxmat,
|
| 86 |
-
xmat,
|
| 87 |
-
x0mat,
|
| 88 |
-
dmask,
|
| 89 |
-
mu,
|
| 90 |
-
rsigma,
|
| 91 |
-
gamma,
|
| 92 |
-
rowscale,
|
| 93 |
-
colscale,
|
| 94 |
-
None,
|
| 95 |
-
None,
|
| 96 |
-
dropout_p,
|
| 97 |
-
1.0,
|
| 98 |
-
0,
|
| 99 |
-
has_residual,
|
| 100 |
-
is_rms_norm,
|
| 101 |
-
)
|
| 102 |
-
# dresidualmat is None if not has_residual
|
| 103 |
-
if colscale is None:
|
| 104 |
-
return dx0mat, dresidualmat, dgamma, dbeta
|
| 105 |
-
else:
|
| 106 |
-
dcolscale = rest[0]
|
| 107 |
-
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def _dropout_add_layer_norm_subset_forward(
|
| 111 |
-
x0,
|
| 112 |
-
residual,
|
| 113 |
-
gamma,
|
| 114 |
-
beta,
|
| 115 |
-
colscale,
|
| 116 |
-
x0_subset,
|
| 117 |
-
out_subset,
|
| 118 |
-
dropout_p,
|
| 119 |
-
epsilon,
|
| 120 |
-
rowscale_const,
|
| 121 |
-
out_numrows,
|
| 122 |
-
residual_in_fp32=False,
|
| 123 |
-
is_rms_norm=False,
|
| 124 |
-
):
|
| 125 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
| 126 |
-
hidden_size = gamma.numel()
|
| 127 |
-
x0mat = x0.view((-1, hidden_size))
|
| 128 |
-
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
| 129 |
-
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
| 130 |
-
out_subset = out_subset.view(-1) if out_subset is not None else None
|
| 131 |
-
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
| 132 |
-
x0mat,
|
| 133 |
-
residualmat,
|
| 134 |
-
gamma,
|
| 135 |
-
beta,
|
| 136 |
-
None,
|
| 137 |
-
colscale,
|
| 138 |
-
x0_subset,
|
| 139 |
-
out_subset,
|
| 140 |
-
dropout_p,
|
| 141 |
-
epsilon,
|
| 142 |
-
rowscale_const,
|
| 143 |
-
out_numrows,
|
| 144 |
-
None,
|
| 145 |
-
residual_in_fp32,
|
| 146 |
-
is_rms_norm,
|
| 147 |
-
)
|
| 148 |
-
# dmask is None if dropout_p == 0.0
|
| 149 |
-
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
| 150 |
-
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def _dropout_add_layer_norm_subset_backward(
|
| 154 |
-
dz,
|
| 155 |
-
dx,
|
| 156 |
-
x,
|
| 157 |
-
x0,
|
| 158 |
-
dmask,
|
| 159 |
-
mu,
|
| 160 |
-
rsigma,
|
| 161 |
-
gamma,
|
| 162 |
-
colscale,
|
| 163 |
-
x0_subset,
|
| 164 |
-
out_subset,
|
| 165 |
-
dropout_p,
|
| 166 |
-
rowscale_const,
|
| 167 |
-
x0_numrows,
|
| 168 |
-
has_residual,
|
| 169 |
-
is_rms_norm=False,
|
| 170 |
-
):
|
| 171 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes
|
| 172 |
-
dx == None means that it was a post-norm architecture
|
| 173 |
-
(x = drop(x0) + residual was not returned in the fwd).
|
| 174 |
-
x0 must not be None if we have colscale.
|
| 175 |
-
"""
|
| 176 |
-
hidden_size = gamma.numel()
|
| 177 |
-
xmat = x.view((-1, hidden_size))
|
| 178 |
-
dzmat = dz.view(-1, hidden_size)
|
| 179 |
-
dxmat = dx.view(xmat.shape) if dx is not None else None
|
| 180 |
-
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
| 181 |
-
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
| 182 |
-
out_subset = out_subset.view(-1) if out_subset is not None else None
|
| 183 |
-
if colscale is not None:
|
| 184 |
-
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
| 185 |
-
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
| 186 |
-
dzmat,
|
| 187 |
-
dxmat,
|
| 188 |
-
xmat,
|
| 189 |
-
x0mat,
|
| 190 |
-
dmask,
|
| 191 |
-
mu,
|
| 192 |
-
rsigma,
|
| 193 |
-
gamma,
|
| 194 |
-
None,
|
| 195 |
-
colscale,
|
| 196 |
-
x0_subset,
|
| 197 |
-
out_subset,
|
| 198 |
-
dropout_p,
|
| 199 |
-
rowscale_const,
|
| 200 |
-
x0_numrows,
|
| 201 |
-
has_residual,
|
| 202 |
-
is_rms_norm,
|
| 203 |
-
)
|
| 204 |
-
# dresidualmat is None if not has_residual
|
| 205 |
-
if colscale is None:
|
| 206 |
-
return dx0mat, dresidualmat, dgamma, dbeta
|
| 207 |
-
else:
|
| 208 |
-
dcolscale = rest[0]
|
| 209 |
-
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def _dropout_add_layer_norm_parallel_residual_forward(
|
| 213 |
-
x0,
|
| 214 |
-
x1,
|
| 215 |
-
residual,
|
| 216 |
-
gamma0,
|
| 217 |
-
beta0,
|
| 218 |
-
gamma1,
|
| 219 |
-
beta1,
|
| 220 |
-
dropout_p,
|
| 221 |
-
epsilon,
|
| 222 |
-
residual_in_fp32=False,
|
| 223 |
-
is_rms_norm=False,
|
| 224 |
-
):
|
| 225 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
| 226 |
-
hidden_size = gamma0.numel()
|
| 227 |
-
x0mat = x0.view((-1, hidden_size))
|
| 228 |
-
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
| 229 |
-
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
| 230 |
-
(
|
| 231 |
-
z0mat,
|
| 232 |
-
z1mat,
|
| 233 |
-
xmat,
|
| 234 |
-
dmask0,
|
| 235 |
-
dmask1,
|
| 236 |
-
mu,
|
| 237 |
-
rsigma,
|
| 238 |
-
) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
|
| 239 |
-
x0mat,
|
| 240 |
-
x1mat,
|
| 241 |
-
residualmat,
|
| 242 |
-
gamma0,
|
| 243 |
-
beta0,
|
| 244 |
-
gamma1,
|
| 245 |
-
beta1,
|
| 246 |
-
dropout_p,
|
| 247 |
-
epsilon,
|
| 248 |
-
None,
|
| 249 |
-
residual_in_fp32,
|
| 250 |
-
is_rms_norm,
|
| 251 |
-
)
|
| 252 |
-
# dmask0 and dmask1 are None if dropout_p == 0.0
|
| 253 |
-
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
| 254 |
-
return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def _dropout_add_layer_norm_parallel_residual_backward(
|
| 258 |
-
dz0,
|
| 259 |
-
dz1,
|
| 260 |
-
dx,
|
| 261 |
-
x,
|
| 262 |
-
dmask0,
|
| 263 |
-
dmask1,
|
| 264 |
-
mu,
|
| 265 |
-
rsigma,
|
| 266 |
-
gamma0,
|
| 267 |
-
gamma1,
|
| 268 |
-
dropout_p,
|
| 269 |
-
has_x1,
|
| 270 |
-
has_residual,
|
| 271 |
-
is_rms_norm=False,
|
| 272 |
-
):
|
| 273 |
-
"""Assume that arguments are contiguous and aligned to 16 bytes
|
| 274 |
-
dx == None means that it was a post-norm architecture
|
| 275 |
-
(x = drop(x0) + residual was not returned in the fwd).
|
| 276 |
-
"""
|
| 277 |
-
hidden_size = gamma0.numel()
|
| 278 |
-
xmat = x.view((-1, hidden_size))
|
| 279 |
-
dz0mat = dz0.view(xmat.shape)
|
| 280 |
-
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
|
| 281 |
-
dxmat = dx.view(xmat.shape) if dx is not None else None
|
| 282 |
-
(
|
| 283 |
-
dx0mat,
|
| 284 |
-
dx1mat,
|
| 285 |
-
dresidualmat,
|
| 286 |
-
dgamma0,
|
| 287 |
-
dbeta0,
|
| 288 |
-
dgamma1,
|
| 289 |
-
dbeta1,
|
| 290 |
-
*rest,
|
| 291 |
-
) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
|
| 292 |
-
dz0mat,
|
| 293 |
-
dz1mat,
|
| 294 |
-
dxmat,
|
| 295 |
-
xmat,
|
| 296 |
-
dmask0,
|
| 297 |
-
dmask1,
|
| 298 |
-
mu,
|
| 299 |
-
rsigma,
|
| 300 |
-
gamma0,
|
| 301 |
-
gamma1,
|
| 302 |
-
dropout_p,
|
| 303 |
-
has_x1,
|
| 304 |
-
has_residual,
|
| 305 |
-
is_rms_norm,
|
| 306 |
-
)
|
| 307 |
-
# dresidualmat is None if not has_residual
|
| 308 |
-
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
class DropoutAddLayerNormFn(torch.autograd.Function):
|
| 312 |
-
@staticmethod
|
| 313 |
-
def forward(
|
| 314 |
-
ctx,
|
| 315 |
-
x0,
|
| 316 |
-
residual,
|
| 317 |
-
gamma,
|
| 318 |
-
beta,
|
| 319 |
-
rowscale,
|
| 320 |
-
colscale,
|
| 321 |
-
dropout_p,
|
| 322 |
-
epsilon,
|
| 323 |
-
residual_in_fp32=False,
|
| 324 |
-
prenorm=False,
|
| 325 |
-
is_rms_norm=False,
|
| 326 |
-
return_dmask=False,
|
| 327 |
-
):
|
| 328 |
-
x0 = maybe_align(x0.contiguous(), 16)
|
| 329 |
-
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
| 330 |
-
gamma = maybe_align(gamma.contiguous(), 16)
|
| 331 |
-
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
| 332 |
-
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
|
| 333 |
-
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
| 334 |
-
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
| 335 |
-
x0,
|
| 336 |
-
residual,
|
| 337 |
-
gamma,
|
| 338 |
-
beta,
|
| 339 |
-
rowscale,
|
| 340 |
-
colscale,
|
| 341 |
-
dropout_p,
|
| 342 |
-
epsilon,
|
| 343 |
-
residual_in_fp32,
|
| 344 |
-
is_rms_norm,
|
| 345 |
-
)
|
| 346 |
-
# Only need to save x0 if we need to compute gradient wrt colscale
|
| 347 |
-
x0_saved = x0 if colscale is not None else None
|
| 348 |
-
ctx.save_for_backward(
|
| 349 |
-
xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
|
| 350 |
-
)
|
| 351 |
-
ctx.prenorm = prenorm
|
| 352 |
-
ctx.dropout_p = dropout_p
|
| 353 |
-
ctx.has_residual = residual is not None
|
| 354 |
-
ctx.is_rms_norm = is_rms_norm
|
| 355 |
-
ctx.has_beta = beta is not None
|
| 356 |
-
if not return_dmask:
|
| 357 |
-
return (
|
| 358 |
-
zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
|
| 359 |
-
)
|
| 360 |
-
else:
|
| 361 |
-
dmask = (
|
| 362 |
-
dmask.view(x0.shape)
|
| 363 |
-
if dropout_p > 0.0
|
| 364 |
-
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
| 365 |
-
)
|
| 366 |
-
ctx.mark_non_differentiable(dmask)
|
| 367 |
-
return (
|
| 368 |
-
(zmat.view(x0.shape), dmask)
|
| 369 |
-
if not prenorm
|
| 370 |
-
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
@staticmethod
|
| 374 |
-
def backward(ctx, dz, *args):
|
| 375 |
-
# assert dz.is_contiguous()
|
| 376 |
-
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
| 377 |
-
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
| 378 |
-
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
|
| 379 |
-
# x0 is None if colscale is None
|
| 380 |
-
dropout_p = ctx.dropout_p
|
| 381 |
-
has_residual = ctx.has_residual
|
| 382 |
-
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
| 383 |
-
dz,
|
| 384 |
-
dx,
|
| 385 |
-
x,
|
| 386 |
-
x0,
|
| 387 |
-
dmask,
|
| 388 |
-
mu,
|
| 389 |
-
rsigma,
|
| 390 |
-
gamma,
|
| 391 |
-
rowscale,
|
| 392 |
-
colscale,
|
| 393 |
-
dropout_p,
|
| 394 |
-
has_residual,
|
| 395 |
-
ctx.is_rms_norm,
|
| 396 |
-
)
|
| 397 |
-
dx0 = dx0mat.view(x.shape)
|
| 398 |
-
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
| 399 |
-
dcolscale = rest[0] if colscale is not None else None
|
| 400 |
-
return (
|
| 401 |
-
dx0,
|
| 402 |
-
dresidual,
|
| 403 |
-
dgamma,
|
| 404 |
-
dbeta if ctx.has_beta else None,
|
| 405 |
-
None,
|
| 406 |
-
dcolscale,
|
| 407 |
-
None,
|
| 408 |
-
None,
|
| 409 |
-
None,
|
| 410 |
-
None,
|
| 411 |
-
None,
|
| 412 |
-
None,
|
| 413 |
-
)
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
| 417 |
-
@staticmethod
|
| 418 |
-
def forward(
|
| 419 |
-
ctx,
|
| 420 |
-
x0,
|
| 421 |
-
residual,
|
| 422 |
-
gamma,
|
| 423 |
-
beta,
|
| 424 |
-
colscale,
|
| 425 |
-
x0_subset,
|
| 426 |
-
out_subset,
|
| 427 |
-
dropout_p,
|
| 428 |
-
epsilon,
|
| 429 |
-
rowscale_const,
|
| 430 |
-
out_numrows,
|
| 431 |
-
residual_in_fp32=False,
|
| 432 |
-
prenorm=False,
|
| 433 |
-
is_rms_norm=False,
|
| 434 |
-
return_dmask=False,
|
| 435 |
-
):
|
| 436 |
-
x0 = maybe_align(x0.contiguous(), 16)
|
| 437 |
-
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
| 438 |
-
gamma = maybe_align(gamma.contiguous(), 16)
|
| 439 |
-
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
| 440 |
-
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
| 441 |
-
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
|
| 442 |
-
x0,
|
| 443 |
-
residual,
|
| 444 |
-
gamma,
|
| 445 |
-
beta,
|
| 446 |
-
colscale,
|
| 447 |
-
x0_subset,
|
| 448 |
-
out_subset,
|
| 449 |
-
dropout_p,
|
| 450 |
-
epsilon,
|
| 451 |
-
rowscale_const,
|
| 452 |
-
out_numrows,
|
| 453 |
-
residual_in_fp32,
|
| 454 |
-
is_rms_norm,
|
| 455 |
-
)
|
| 456 |
-
# Only need to save x0 if we need to compute gradient wrt colscale
|
| 457 |
-
x0_saved = x0 if colscale is not None else None
|
| 458 |
-
x_shape = (-1, *x0.shape[1:])
|
| 459 |
-
ctx.save_for_backward(
|
| 460 |
-
xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
|
| 461 |
-
)
|
| 462 |
-
ctx.prenorm = prenorm
|
| 463 |
-
ctx.dropout_p = dropout_p
|
| 464 |
-
ctx.rowscale_const = rowscale_const
|
| 465 |
-
ctx.x0_numrows = x0.shape[:-1].numel()
|
| 466 |
-
ctx.has_residual = residual is not None
|
| 467 |
-
ctx.is_rms_norm = is_rms_norm
|
| 468 |
-
ctx.has_beta = beta is not None
|
| 469 |
-
z_shape = (-1, *x0.shape[1:])
|
| 470 |
-
if not return_dmask:
|
| 471 |
-
return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
|
| 472 |
-
else:
|
| 473 |
-
z = zmat.view(z_shape)
|
| 474 |
-
dmask = (
|
| 475 |
-
dmask.view(x0.shape)
|
| 476 |
-
if dropout_p > 0.0
|
| 477 |
-
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
| 478 |
-
)
|
| 479 |
-
ctx.mark_non_differentiable(dmask)
|
| 480 |
-
return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
|
| 481 |
-
|
| 482 |
-
@staticmethod
|
| 483 |
-
def backward(ctx, dz, *args):
|
| 484 |
-
# assert dz.is_contiguous()
|
| 485 |
-
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
| 486 |
-
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
| 487 |
-
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
|
| 488 |
-
# x0 is None if colscale is None
|
| 489 |
-
dropout_p = ctx.dropout_p
|
| 490 |
-
has_residual = ctx.has_residual
|
| 491 |
-
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
|
| 492 |
-
dz,
|
| 493 |
-
dx,
|
| 494 |
-
x,
|
| 495 |
-
x0,
|
| 496 |
-
dmask,
|
| 497 |
-
mu,
|
| 498 |
-
rsigma,
|
| 499 |
-
gamma,
|
| 500 |
-
colscale,
|
| 501 |
-
x0_subset,
|
| 502 |
-
out_subset,
|
| 503 |
-
dropout_p,
|
| 504 |
-
ctx.rowscale_const,
|
| 505 |
-
ctx.x0_numrows,
|
| 506 |
-
has_residual,
|
| 507 |
-
ctx.is_rms_norm,
|
| 508 |
-
)
|
| 509 |
-
dx0 = dx0mat.view(-1, *x.shape[1:])
|
| 510 |
-
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
| 511 |
-
dcolscale = rest[0] if colscale is not None else None
|
| 512 |
-
return (
|
| 513 |
-
dx0,
|
| 514 |
-
dresidual,
|
| 515 |
-
dgamma,
|
| 516 |
-
dbeta if ctx.has_beta else None,
|
| 517 |
-
dcolscale,
|
| 518 |
-
None,
|
| 519 |
-
None,
|
| 520 |
-
None,
|
| 521 |
-
None,
|
| 522 |
-
None,
|
| 523 |
-
None,
|
| 524 |
-
None,
|
| 525 |
-
None,
|
| 526 |
-
None,
|
| 527 |
-
None,
|
| 528 |
-
)
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
| 532 |
-
@staticmethod
|
| 533 |
-
def forward(
|
| 534 |
-
ctx,
|
| 535 |
-
x0,
|
| 536 |
-
x1,
|
| 537 |
-
residual,
|
| 538 |
-
gamma0,
|
| 539 |
-
beta0,
|
| 540 |
-
gamma1,
|
| 541 |
-
beta1,
|
| 542 |
-
dropout_p,
|
| 543 |
-
epsilon,
|
| 544 |
-
residual_in_fp32=False,
|
| 545 |
-
prenorm=False,
|
| 546 |
-
is_rms_norm=False,
|
| 547 |
-
return_dmask=False,
|
| 548 |
-
):
|
| 549 |
-
x0 = maybe_align(x0.contiguous(), 16)
|
| 550 |
-
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
|
| 551 |
-
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
| 552 |
-
gamma0 = maybe_align(gamma0.contiguous(), 16)
|
| 553 |
-
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
|
| 554 |
-
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
|
| 555 |
-
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
|
| 556 |
-
(
|
| 557 |
-
z0mat,
|
| 558 |
-
z1mat,
|
| 559 |
-
xmat,
|
| 560 |
-
dmask0,
|
| 561 |
-
dmask1,
|
| 562 |
-
mu,
|
| 563 |
-
rsigma,
|
| 564 |
-
) = _dropout_add_layer_norm_parallel_residual_forward(
|
| 565 |
-
x0,
|
| 566 |
-
x1,
|
| 567 |
-
residual,
|
| 568 |
-
gamma0,
|
| 569 |
-
beta0,
|
| 570 |
-
gamma1,
|
| 571 |
-
beta1,
|
| 572 |
-
dropout_p,
|
| 573 |
-
epsilon,
|
| 574 |
-
residual_in_fp32,
|
| 575 |
-
is_rms_norm,
|
| 576 |
-
)
|
| 577 |
-
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
|
| 578 |
-
ctx.prenorm = prenorm
|
| 579 |
-
ctx.dropout_p = dropout_p
|
| 580 |
-
ctx.has_x1 = x1 is not None
|
| 581 |
-
ctx.has_residual = residual is not None
|
| 582 |
-
ctx.is_rms_norm = is_rms_norm
|
| 583 |
-
ctx.has_beta = beta0 is not None
|
| 584 |
-
z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
|
| 585 |
-
if not return_dmask:
|
| 586 |
-
return z if not prenorm else (*z, xmat.view(x0.shape))
|
| 587 |
-
else:
|
| 588 |
-
dmask0 = (
|
| 589 |
-
dmask0.view(x0.shape)
|
| 590 |
-
if dropout_p > 0.0
|
| 591 |
-
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
| 592 |
-
)
|
| 593 |
-
dmask1 = (
|
| 594 |
-
dmask1.view(x0.shape)
|
| 595 |
-
if dropout_p > 0.0 and x1 is not None
|
| 596 |
-
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
| 597 |
-
)
|
| 598 |
-
ctx.mark_non_differentiable(dmask0)
|
| 599 |
-
ctx.mark_non_differentiable(dmask1)
|
| 600 |
-
return (
|
| 601 |
-
(*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
|
| 602 |
-
)
|
| 603 |
-
|
| 604 |
-
@staticmethod
|
| 605 |
-
def backward(ctx, dz0, dz1, *args):
|
| 606 |
-
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
|
| 607 |
-
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
|
| 608 |
-
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
| 609 |
-
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
|
| 610 |
-
dropout_p = ctx.dropout_p
|
| 611 |
-
has_x1 = ctx.has_x1
|
| 612 |
-
has_residual = ctx.has_residual
|
| 613 |
-
(
|
| 614 |
-
dx0mat,
|
| 615 |
-
dx1mat,
|
| 616 |
-
dresidualmat,
|
| 617 |
-
dgamma0,
|
| 618 |
-
dbeta0,
|
| 619 |
-
dgamma1,
|
| 620 |
-
dbeta1,
|
| 621 |
-
) = _dropout_add_layer_norm_parallel_residual_backward(
|
| 622 |
-
dz0,
|
| 623 |
-
dz1,
|
| 624 |
-
dx,
|
| 625 |
-
x,
|
| 626 |
-
dmask0,
|
| 627 |
-
dmask1,
|
| 628 |
-
mu,
|
| 629 |
-
rsigma,
|
| 630 |
-
gamma0,
|
| 631 |
-
gamma1,
|
| 632 |
-
dropout_p,
|
| 633 |
-
has_x1,
|
| 634 |
-
has_residual,
|
| 635 |
-
ctx.is_rms_norm,
|
| 636 |
-
)
|
| 637 |
-
dx0 = dx0mat.view(x.shape)
|
| 638 |
-
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
| 639 |
-
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
| 640 |
-
return (
|
| 641 |
-
dx0,
|
| 642 |
-
dx1,
|
| 643 |
-
dresidual,
|
| 644 |
-
dgamma0,
|
| 645 |
-
dbeta0 if ctx.has_beta else None,
|
| 646 |
-
dgamma1,
|
| 647 |
-
dbeta1 if ctx.has_beta else None,
|
| 648 |
-
None,
|
| 649 |
-
None,
|
| 650 |
-
None,
|
| 651 |
-
None,
|
| 652 |
-
None,
|
| 653 |
-
None,
|
| 654 |
-
)
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
def layer_norm(x, weight, bias, epsilon):
|
| 658 |
-
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
def dropout_add_layer_norm(
|
| 662 |
-
x0,
|
| 663 |
-
residual,
|
| 664 |
-
weight,
|
| 665 |
-
bias,
|
| 666 |
-
dropout_p,
|
| 667 |
-
epsilon,
|
| 668 |
-
rowscale=None,
|
| 669 |
-
layerscale=None,
|
| 670 |
-
prenorm=False,
|
| 671 |
-
residual_in_fp32=False,
|
| 672 |
-
return_dropout_mask=False,
|
| 673 |
-
):
|
| 674 |
-
"""residual_in_fp32 only has an effect if residual is None.
|
| 675 |
-
Otherwise residual dtype is residual.dtype.
|
| 676 |
-
"""
|
| 677 |
-
return DropoutAddLayerNormFn.apply(
|
| 678 |
-
x0,
|
| 679 |
-
residual,
|
| 680 |
-
weight,
|
| 681 |
-
bias,
|
| 682 |
-
rowscale,
|
| 683 |
-
layerscale,
|
| 684 |
-
dropout_p,
|
| 685 |
-
epsilon,
|
| 686 |
-
residual_in_fp32,
|
| 687 |
-
prenorm,
|
| 688 |
-
False,
|
| 689 |
-
return_dropout_mask,
|
| 690 |
-
)
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
def dropout_add_layer_norm_subset(
|
| 694 |
-
x0,
|
| 695 |
-
residual,
|
| 696 |
-
weight,
|
| 697 |
-
bias,
|
| 698 |
-
dropout_p,
|
| 699 |
-
epsilon,
|
| 700 |
-
layerscale=None,
|
| 701 |
-
x0_subset=None,
|
| 702 |
-
out_subset=None,
|
| 703 |
-
rowscale_const=1.0,
|
| 704 |
-
out_numrows=0,
|
| 705 |
-
prenorm=False,
|
| 706 |
-
residual_in_fp32=False,
|
| 707 |
-
return_dropout_mask=False,
|
| 708 |
-
):
|
| 709 |
-
"""residual_in_fp32 only has an effect if residual is None.
|
| 710 |
-
Otherwise residual dtype is residual.dtype.
|
| 711 |
-
"""
|
| 712 |
-
return DropoutAddLayerNormSubsetFn.apply(
|
| 713 |
-
x0,
|
| 714 |
-
residual,
|
| 715 |
-
weight,
|
| 716 |
-
bias,
|
| 717 |
-
layerscale,
|
| 718 |
-
x0_subset,
|
| 719 |
-
out_subset,
|
| 720 |
-
dropout_p,
|
| 721 |
-
epsilon,
|
| 722 |
-
rowscale_const,
|
| 723 |
-
out_numrows,
|
| 724 |
-
residual_in_fp32,
|
| 725 |
-
prenorm,
|
| 726 |
-
False,
|
| 727 |
-
return_dropout_mask,
|
| 728 |
-
)
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
def dropout_add_layer_norm_parallel_residual(
|
| 732 |
-
x0,
|
| 733 |
-
x1,
|
| 734 |
-
residual,
|
| 735 |
-
weight0,
|
| 736 |
-
bias0,
|
| 737 |
-
weight1,
|
| 738 |
-
bias1,
|
| 739 |
-
dropout_p,
|
| 740 |
-
epsilon,
|
| 741 |
-
prenorm=False,
|
| 742 |
-
residual_in_fp32=False,
|
| 743 |
-
return_dropout_mask=False,
|
| 744 |
-
):
|
| 745 |
-
"""residual_in_fp32 only has an effect if residual is None.
|
| 746 |
-
Otherwise residual dtype is residual.dtype.
|
| 747 |
-
"""
|
| 748 |
-
return DropoutAddLayerNormParallelResidualFn.apply(
|
| 749 |
-
x0,
|
| 750 |
-
x1,
|
| 751 |
-
residual,
|
| 752 |
-
weight0,
|
| 753 |
-
bias0,
|
| 754 |
-
weight1,
|
| 755 |
-
bias1,
|
| 756 |
-
dropout_p,
|
| 757 |
-
epsilon,
|
| 758 |
-
residual_in_fp32,
|
| 759 |
-
prenorm,
|
| 760 |
-
False,
|
| 761 |
-
return_dropout_mask,
|
| 762 |
-
)
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
class DropoutAddLayerNorm(torch.nn.Module):
|
| 766 |
-
def __init__(
|
| 767 |
-
self,
|
| 768 |
-
hidden_size,
|
| 769 |
-
prenorm=False,
|
| 770 |
-
p=0.0,
|
| 771 |
-
eps=1e-5,
|
| 772 |
-
residual_in_fp32=False,
|
| 773 |
-
device=None,
|
| 774 |
-
dtype=None,
|
| 775 |
-
):
|
| 776 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
| 777 |
-
super().__init__()
|
| 778 |
-
self.prenorm = prenorm
|
| 779 |
-
self.p = p
|
| 780 |
-
self.eps = eps
|
| 781 |
-
self.residual_in_fp32 = residual_in_fp32
|
| 782 |
-
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 783 |
-
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 784 |
-
self.reset_parameters()
|
| 785 |
-
|
| 786 |
-
def reset_parameters(self):
|
| 787 |
-
init.ones_(self.weight)
|
| 788 |
-
init.zeros_(self.bias)
|
| 789 |
-
|
| 790 |
-
def forward(self, x0, residual=None):
|
| 791 |
-
return dropout_add_layer_norm(
|
| 792 |
-
x0,
|
| 793 |
-
residual,
|
| 794 |
-
self.weight,
|
| 795 |
-
self.bias,
|
| 796 |
-
self.p if self.training else 0.0,
|
| 797 |
-
self.eps,
|
| 798 |
-
prenorm=self.prenorm,
|
| 799 |
-
residual_in_fp32=self.residual_in_fp32,
|
| 800 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/{torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:790cd814bbfcaf7ff83b5c68bcb91091a67f34e92b9a2494e2856462e71a3141
|
| 3 |
+
size 716945944
|
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/{torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b17984ef79fc9d6427c8efe0a8cc8f1f6e2777f9a8641b86556b7bb2359626ab
|
| 3 |
+
size 712024816
|
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/{_layer_norm_f622ea1_dirty.abi3.so → _layer_norm_f3fd6bf.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7629b13b777a390df75374fc60d85311679a56a5bbd9969e138822e5c0fe2b1e
|
| 3 |
+
size 1231333360
|
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/{torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6ffc9d5651e8de6440f2d4f58018a5ded07634582ae03eec5b9edf428f613a6
|
| 3 |
+
size 712024904
|
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df39795e047e962019cbecbb11f93d8ee1fcfb49ed8326f2edc267bc0d90da08
|
| 3 |
+
size 1231337936
|
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:d51ec6b6da7095cf5fc18493eb4b0b1c20485f01dff4b38370979ea3d0a9dd60
|
| 3 |
-
size 1231337968
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfef6947945f8f126a284c6a8ab861e180a5e628992eeb0b4b7c7914c50a59c2
|
| 3 |
+
size 1283037344
|
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9080934ece3b5e09db6178b1baa15b8baf9f6873e234a951a2122071e1190fba
|
| 3 |
-
size 1283037376
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
{torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/__init__.py
RENAMED
|
File without changes
|
build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2bdb57c0889ade2fc574156873c1d4b543796f2e8ad6a894be82ee2785459c9b
|
| 3 |
+
size 712029160
|
build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
{torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/layers.py
RENAMED
|
File without changes
|
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from ._ops import ops
|
| 5 |
+
|
| 6 |
+
from . import layers
|
| 7 |
+
|
| 8 |
+
def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
|
| 9 |
+
return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
|
| 10 |
+
|
| 11 |
+
def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
|
| 12 |
+
return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
|
| 13 |
+
|
| 14 |
+
def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
|
| 15 |
+
return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
|
| 16 |
+
|
| 17 |
+
def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
|
| 18 |
+
return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"layers",
|
| 22 |
+
"dropout_add_ln_fwd",
|
| 23 |
+
"dropout_add_ln_bwd",
|
| 24 |
+
"dropout_add_ln_parallel_residual_fwd",
|
| 25 |
+
"dropout_add_ln_parallel_residual_bwd",
|
| 26 |
+
]
|
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03e6e7ecbf276b306d89607100f78f2ce8b3385a77594676dbf0daabdce26fc7
|
| 3 |
+
size 1231338080
|
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from ._ops import ops
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerNorm(nn.Module):
|
| 8 |
+
weight: torch.Tensor
|
| 9 |
+
variance_epsilon: float
|
| 10 |
+
|
| 11 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
output = ops.dropout_add_ln_fwd(
|
| 13 |
+
hidden_states.view(-1, hidden_states.shape[-1]),
|
| 14 |
+
gamma = self.weight,
|
| 15 |
+
beta = None,
|
| 16 |
+
rowscale = None,
|
| 17 |
+
colscale = None,
|
| 18 |
+
x0_subset = None,
|
| 19 |
+
z_subset = None,
|
| 20 |
+
dropout_p = 0,
|
| 21 |
+
epsilon = self.variance_epsilon,
|
| 22 |
+
rowscale_const = 1.0,
|
| 23 |
+
z_numrows = hidden_states.shape[1],
|
| 24 |
+
gen = None,
|
| 25 |
+
residual_in_fp32 = False,
|
| 26 |
+
is_rms_norm = False,
|
| 27 |
+
)
|
| 28 |
+
return output[0].view(hidden_states.shape)
|
| 29 |
+
|
| 30 |
+
class LlamaRMSNorm(nn.Module):
|
| 31 |
+
weight: torch.Tensor
|
| 32 |
+
variance_epsilon: float
|
| 33 |
+
|
| 34 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
output = ops.dropout_add_ln_fwd(
|
| 36 |
+
hidden_states.view(-1, hidden_states.shape[-1]),
|
| 37 |
+
gamma = self.weight,
|
| 38 |
+
beta = None,
|
| 39 |
+
rowscale = None,
|
| 40 |
+
colscale = None,
|
| 41 |
+
x0_subset = None,
|
| 42 |
+
z_subset = None,
|
| 43 |
+
dropout_p = 0,
|
| 44 |
+
epsilon = self.variance_epsilon,
|
| 45 |
+
rowscale_const = 1.0,
|
| 46 |
+
z_numrows = hidden_states.shape[1],
|
| 47 |
+
gen = None,
|
| 48 |
+
residual_in_fp32 = False,
|
| 49 |
+
is_rms_norm = True,
|
| 50 |
+
)
|
| 51 |
+
return output[0].view(hidden_states.shape)
|
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from ._ops import ops
|
| 5 |
+
|
| 6 |
+
from . import layers
|
| 7 |
+
|
| 8 |
+
def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
|
| 9 |
+
return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
|
| 10 |
+
|
| 11 |
+
def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
|
| 12 |
+
return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
|
| 13 |
+
|
| 14 |
+
def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
|
| 15 |
+
return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
|
| 16 |
+
|
| 17 |
+
def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
|
| 18 |
+
return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"layers",
|
| 22 |
+
"dropout_add_ln_fwd",
|
| 23 |
+
"dropout_add_ln_bwd",
|
| 24 |
+
"dropout_add_ln_parallel_residual_fwd",
|
| 25 |
+
"dropout_add_ln_parallel_residual_bwd",
|
| 26 |
+
]
|
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:322e2d8fc69447be95ef7b6e85267e8769f1284419baa606732a77b1980a834d
|
| 3 |
+
size 1238333264
|
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _layer_norm_f3fd6bf
|
| 3 |
+
ops = torch.ops._layer_norm_f3fd6bf
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_layer_norm_f3fd6bf::{op_name}"
|
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from ._ops import ops
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerNorm(nn.Module):
|
| 8 |
+
weight: torch.Tensor
|
| 9 |
+
variance_epsilon: float
|
| 10 |
+
|
| 11 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
output = ops.dropout_add_ln_fwd(
|
| 13 |
+
hidden_states.view(-1, hidden_states.shape[-1]),
|
| 14 |
+
gamma = self.weight,
|
| 15 |
+
beta = None,
|
| 16 |
+
rowscale = None,
|
| 17 |
+
colscale = None,
|
| 18 |
+
x0_subset = None,
|
| 19 |
+
z_subset = None,
|
| 20 |
+
dropout_p = 0,
|
| 21 |
+
epsilon = self.variance_epsilon,
|
| 22 |
+
rowscale_const = 1.0,
|
| 23 |
+
z_numrows = hidden_states.shape[1],
|
| 24 |
+
gen = None,
|
| 25 |
+
residual_in_fp32 = False,
|
| 26 |
+
is_rms_norm = False,
|
| 27 |
+
)
|
| 28 |
+
return output[0].view(hidden_states.shape)
|
| 29 |
+
|
| 30 |
+
class LlamaRMSNorm(nn.Module):
|
| 31 |
+
weight: torch.Tensor
|
| 32 |
+
variance_epsilon: float
|
| 33 |
+
|
| 34 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
output = ops.dropout_add_ln_fwd(
|
| 36 |
+
hidden_states.view(-1, hidden_states.shape[-1]),
|
| 37 |
+
gamma = self.weight,
|
| 38 |
+
beta = None,
|
| 39 |
+
rowscale = None,
|
| 40 |
+
colscale = None,
|
| 41 |
+
x0_subset = None,
|
| 42 |
+
z_subset = None,
|
| 43 |
+
dropout_p = 0,
|
| 44 |
+
epsilon = self.variance_epsilon,
|
| 45 |
+
rowscale_const = 1.0,
|
| 46 |
+
z_numrows = hidden_states.shape[1],
|
| 47 |
+
gen = None,
|
| 48 |
+
residual_in_fp32 = False,
|
| 49 |
+
is_rms_norm = True,
|
| 50 |
+
)
|
| 51 |
+
return output[0].view(hidden_states.shape)
|
cmake/hipify.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
|
| 4 |
-
# From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py
|
| 5 |
-
|
| 6 |
-
#
|
| 7 |
-
# A command line tool for running pytorch's hipify preprocessor on CUDA
|
| 8 |
-
# source files.
|
| 9 |
-
#
|
| 10 |
-
# See https://github.com/ROCm/hipify_torch
|
| 11 |
-
# and <torch install dir>/utils/hipify/hipify_python.py
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
import argparse
|
| 15 |
-
import os
|
| 16 |
-
import shutil
|
| 17 |
-
|
| 18 |
-
from torch.utils.hipify.hipify_python import hipify
|
| 19 |
-
|
| 20 |
-
if __name__ == '__main__':
|
| 21 |
-
parser = argparse.ArgumentParser()
|
| 22 |
-
|
| 23 |
-
# Project directory where all the source + include files live.
|
| 24 |
-
parser.add_argument(
|
| 25 |
-
"-p",
|
| 26 |
-
"--project_dir",
|
| 27 |
-
help="The project directory.",
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# Directory where hipified files are written.
|
| 31 |
-
parser.add_argument(
|
| 32 |
-
"-o",
|
| 33 |
-
"--output_dir",
|
| 34 |
-
help="The output directory.",
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
# Source files to convert.
|
| 38 |
-
parser.add_argument("sources",
|
| 39 |
-
help="Source files to hipify.",
|
| 40 |
-
nargs="*",
|
| 41 |
-
default=[])
|
| 42 |
-
|
| 43 |
-
args = parser.parse_args()
|
| 44 |
-
|
| 45 |
-
# Limit include scope to project_dir only
|
| 46 |
-
includes = [os.path.join(args.project_dir, '*')]
|
| 47 |
-
|
| 48 |
-
# Get absolute path for all source files.
|
| 49 |
-
extra_files = [os.path.abspath(s) for s in args.sources]
|
| 50 |
-
|
| 51 |
-
# Copy sources from project directory to output directory.
|
| 52 |
-
# The directory might already exist to hold object files so we ignore that.
|
| 53 |
-
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
| 54 |
-
|
| 55 |
-
hipify_result = hipify(project_directory=args.project_dir,
|
| 56 |
-
output_directory=args.output_dir,
|
| 57 |
-
header_include_dirs=[],
|
| 58 |
-
includes=includes,
|
| 59 |
-
extra_files=extra_files,
|
| 60 |
-
show_detailed=True,
|
| 61 |
-
is_pytorch_extension=True,
|
| 62 |
-
hipify_extra_files_only=True)
|
| 63 |
-
|
| 64 |
-
hipified_sources = []
|
| 65 |
-
for source in args.sources:
|
| 66 |
-
s_abs = os.path.abspath(source)
|
| 67 |
-
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
| 68 |
-
(s_abs in hipify_result
|
| 69 |
-
and hipify_result[s_abs].hipified_path is not None)
|
| 70 |
-
else s_abs)
|
| 71 |
-
hipified_sources.append(hipified_s_abs)
|
| 72 |
-
|
| 73 |
-
assert (len(hipified_sources) == len(args.sources))
|
| 74 |
-
|
| 75 |
-
# Print hipified source files.
|
| 76 |
-
print("\n".join(hipified_sources))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cmake/utils.cmake
DELETED
|
@@ -1,545 +0,0 @@
|
|
| 1 |
-
# Vendored from vLLM:
|
| 2 |
-
#
|
| 3 |
-
# https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
|
| 4 |
-
#
|
| 5 |
-
# Attempt to find the python package that uses the same python executable as
|
| 6 |
-
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
|
| 7 |
-
#
|
| 8 |
-
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
| 9 |
-
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
| 10 |
-
set(Python_EXECUTABLE ${EXECUTABLE})
|
| 11 |
-
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
|
| 12 |
-
if (NOT Python_FOUND)
|
| 13 |
-
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
| 14 |
-
endif()
|
| 15 |
-
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
|
| 16 |
-
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
|
| 17 |
-
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
|
| 18 |
-
message(FATAL_ERROR
|
| 19 |
-
"Python version (${_VER}) is not one of the supported versions: "
|
| 20 |
-
"${_SUPPORTED_VERSIONS_LIST}.")
|
| 21 |
-
endif()
|
| 22 |
-
message(STATUS "Found python matching: ${EXECUTABLE}.")
|
| 23 |
-
endmacro()
|
| 24 |
-
|
| 25 |
-
#
|
| 26 |
-
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
|
| 27 |
-
# has trailing whitespace stripped. If an error is encountered when running
|
| 28 |
-
# python, a fatal message `ERR_MSG` is issued.
|
| 29 |
-
#
|
| 30 |
-
function (run_python OUT EXPR ERR_MSG)
|
| 31 |
-
execute_process(
|
| 32 |
-
COMMAND
|
| 33 |
-
"${Python_EXECUTABLE}" "-c" "${EXPR}"
|
| 34 |
-
OUTPUT_VARIABLE PYTHON_OUT
|
| 35 |
-
RESULT_VARIABLE PYTHON_ERROR_CODE
|
| 36 |
-
ERROR_VARIABLE PYTHON_STDERR
|
| 37 |
-
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
| 38 |
-
|
| 39 |
-
if(NOT PYTHON_ERROR_CODE EQUAL 0)
|
| 40 |
-
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
|
| 41 |
-
endif()
|
| 42 |
-
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
|
| 43 |
-
endfunction()
|
| 44 |
-
|
| 45 |
-
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
|
| 46 |
-
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
|
| 47 |
-
macro (append_cmake_prefix_path PKG EXPR)
|
| 48 |
-
run_python(_PREFIX_PATH
|
| 49 |
-
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
|
| 50 |
-
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
|
| 51 |
-
endmacro()
|
| 52 |
-
|
| 53 |
-
#
|
| 54 |
-
# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
|
| 55 |
-
# of CUDA source files. The names of the corresponding "hipified" sources are
|
| 56 |
-
# stored in `OUT_SRCS`.
|
| 57 |
-
#
|
| 58 |
-
function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
| 59 |
-
#
|
| 60 |
-
# Split into C++ and non-C++ (i.e. CUDA) sources.
|
| 61 |
-
#
|
| 62 |
-
set(NODUP_SRCS ${ORIG_SRCS})
|
| 63 |
-
list(REMOVE_DUPLICATES NODUP_SRCS)
|
| 64 |
-
set(SRCS ${NODUP_SRCS})
|
| 65 |
-
set(CXX_SRCS ${NODUP_SRCS})
|
| 66 |
-
list(FILTER SRCS INCLUDE REGEX "\.cu$")
|
| 67 |
-
list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
|
| 68 |
-
|
| 69 |
-
#
|
| 70 |
-
# Generate ROCm/HIP source file names from CUDA file names.
|
| 71 |
-
# Since HIP files are generated code, they will appear in the build area
|
| 72 |
-
# `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
|
| 73 |
-
#
|
| 74 |
-
set(HIP_SRCS)
|
| 75 |
-
foreach (SRC ${SRCS})
|
| 76 |
-
get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
|
| 77 |
-
string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
|
| 78 |
-
string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
|
| 79 |
-
|
| 80 |
-
if(include_dirs)
|
| 81 |
-
# Copy over include directories from the original CUDA file.
|
| 82 |
-
set_source_files_properties(
|
| 83 |
-
${SRC}
|
| 84 |
-
PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
|
| 85 |
-
endif()
|
| 86 |
-
|
| 87 |
-
list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
|
| 88 |
-
endforeach()
|
| 89 |
-
|
| 90 |
-
add_custom_target(
|
| 91 |
-
hipify${NAME}
|
| 92 |
-
COMMAND "${Python_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
|
| 93 |
-
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
|
| 94 |
-
BYPRODUCTS ${HIP_SRCS}
|
| 95 |
-
COMMENT "Running hipify on ${NAME} extension source files.")
|
| 96 |
-
|
| 97 |
-
# Swap out original extension sources with hipified sources.
|
| 98 |
-
list(APPEND HIP_SRCS ${CXX_SRCS})
|
| 99 |
-
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
|
| 100 |
-
endfunction()
|
| 101 |
-
|
| 102 |
-
#
|
| 103 |
-
# Get additional GPU compiler flags from torch.
|
| 104 |
-
#
|
| 105 |
-
function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
| 106 |
-
if (${GPU_LANG} STREQUAL "CUDA")
|
| 107 |
-
#
|
| 108 |
-
# Get common NVCC flags from torch.
|
| 109 |
-
#
|
| 110 |
-
run_python(GPU_FLAGS
|
| 111 |
-
"from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
|
| 112 |
-
"Failed to determine torch nvcc compiler flags")
|
| 113 |
-
|
| 114 |
-
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
| 115 |
-
list(APPEND GPU_FLAGS "-DENABLE_FP8")
|
| 116 |
-
list(REMOVE_ITEM GPU_FLAGS
|
| 117 |
-
"-D__CUDA_NO_HALF_OPERATORS__"
|
| 118 |
-
"-D__CUDA_NO_HALF_CONVERSIONS__"
|
| 119 |
-
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
| 120 |
-
"-D__CUDA_NO_HALF2_OPERATORS__")
|
| 121 |
-
endif()
|
| 122 |
-
|
| 123 |
-
elseif(${GPU_LANG} STREQUAL "HIP")
|
| 124 |
-
#
|
| 125 |
-
# Get common HIP/HIPCC flags from torch.
|
| 126 |
-
#
|
| 127 |
-
run_python(GPU_FLAGS
|
| 128 |
-
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
|
| 129 |
-
"Failed to determine torch nvcc compiler flags")
|
| 130 |
-
|
| 131 |
-
list(APPEND GPU_FLAGS
|
| 132 |
-
"-DUSE_ROCM"
|
| 133 |
-
"-DENABLE_FP8"
|
| 134 |
-
"-U__HIP_NO_HALF_CONVERSIONS__"
|
| 135 |
-
"-U__HIP_NO_HALF_OPERATORS__"
|
| 136 |
-
"-fno-gpu-rdc")
|
| 137 |
-
|
| 138 |
-
endif()
|
| 139 |
-
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
|
| 140 |
-
endfunction()
|
| 141 |
-
|
| 142 |
-
# Macro for converting a `gencode` version number to a cmake version number.
|
| 143 |
-
macro(string_to_ver OUT_VER IN_STR)
|
| 144 |
-
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
| 145 |
-
endmacro()
|
| 146 |
-
|
| 147 |
-
#
|
| 148 |
-
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
|
| 149 |
-
# `CUDA_ARCH_FLAGS`.
|
| 150 |
-
#
|
| 151 |
-
# Example:
|
| 152 |
-
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
|
| 153 |
-
# clear_cuda_arches(CUDA_ARCH_FLAGS)
|
| 154 |
-
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
|
| 155 |
-
# CMAKE_CUDA_FLAGS="-Wall"
|
| 156 |
-
#
|
| 157 |
-
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
|
| 158 |
-
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
| 159 |
-
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
|
| 160 |
-
${CMAKE_CUDA_FLAGS})
|
| 161 |
-
|
| 162 |
-
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
| 163 |
-
# and passed back via the `CUDA_ARCHITECTURES` property.
|
| 164 |
-
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
| 165 |
-
${CMAKE_CUDA_FLAGS})
|
| 166 |
-
endmacro()
|
| 167 |
-
|
| 168 |
-
#
|
| 169 |
-
# Extract unique CUDA architectures from a list of compute capabilities codes in
|
| 170 |
-
# the form `<major><minor>[<letter>]`, convert them to the form sort
|
| 171 |
-
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
|
| 172 |
-
# stores them in `OUT_ARCHES`.
|
| 173 |
-
#
|
| 174 |
-
# Example:
|
| 175 |
-
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
|
| 176 |
-
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
|
| 177 |
-
# OUT_ARCHES="7.5;...;9.0"
|
| 178 |
-
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
|
| 179 |
-
set(_CUDA_ARCHES)
|
| 180 |
-
foreach(_ARCH ${CUDA_ARCH_FLAGS})
|
| 181 |
-
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
| 182 |
-
if (_COMPUTE)
|
| 183 |
-
set(_COMPUTE ${CMAKE_MATCH_1})
|
| 184 |
-
endif()
|
| 185 |
-
|
| 186 |
-
string_to_ver(_COMPUTE_VER ${_COMPUTE})
|
| 187 |
-
list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
|
| 188 |
-
endforeach()
|
| 189 |
-
|
| 190 |
-
list(REMOVE_DUPLICATES _CUDA_ARCHES)
|
| 191 |
-
list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
|
| 192 |
-
set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
|
| 193 |
-
endfunction()
|
| 194 |
-
|
| 195 |
-
#
|
| 196 |
-
# For a specific file set the `-gencode` flag in compile options conditionally
|
| 197 |
-
# for the CUDA language.
|
| 198 |
-
#
|
| 199 |
-
# Example:
|
| 200 |
-
# set_gencode_flag_for_srcs(
|
| 201 |
-
# SRCS "foo.cu"
|
| 202 |
-
# ARCH "compute_75"
|
| 203 |
-
# CODE "sm_75")
|
| 204 |
-
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
|
| 205 |
-
# `foo.cu` (only for the CUDA language).
|
| 206 |
-
#
|
| 207 |
-
macro(set_gencode_flag_for_srcs)
|
| 208 |
-
set(options)
|
| 209 |
-
set(oneValueArgs ARCH CODE)
|
| 210 |
-
set(multiValueArgs SRCS)
|
| 211 |
-
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
|
| 212 |
-
"${multiValueArgs}" ${ARGN} )
|
| 213 |
-
set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
|
| 214 |
-
set_property(
|
| 215 |
-
SOURCE ${arg_SRCS}
|
| 216 |
-
APPEND PROPERTY
|
| 217 |
-
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
|
| 221 |
-
endmacro(set_gencode_flag_for_srcs)
|
| 222 |
-
|
| 223 |
-
#
|
| 224 |
-
# For a list of source files set the `-gencode` flags in the files specific
|
| 225 |
-
# compile options (specifically for the CUDA language).
|
| 226 |
-
#
|
| 227 |
-
# arguments are:
|
| 228 |
-
# SRCS: list of source files
|
| 229 |
-
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
|
| 230 |
-
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
|
| 231 |
-
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
|
| 232 |
-
# that is larger than BUILD_PTX_FOR_ARCH.
|
| 233 |
-
#
|
| 234 |
-
macro(set_gencode_flags_for_srcs)
|
| 235 |
-
set(options)
|
| 236 |
-
set(oneValueArgs BUILD_PTX_FOR_ARCH)
|
| 237 |
-
set(multiValueArgs SRCS CUDA_ARCHS)
|
| 238 |
-
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
|
| 239 |
-
"${multiValueArgs}" ${ARGN} )
|
| 240 |
-
|
| 241 |
-
foreach(_ARCH ${arg_CUDA_ARCHS})
|
| 242 |
-
# handle +PTX suffix: generate both sm and ptx codes if requested
|
| 243 |
-
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
| 244 |
-
if(NOT _HAS_PTX EQUAL -1)
|
| 245 |
-
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
| 246 |
-
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
| 247 |
-
set_gencode_flag_for_srcs(
|
| 248 |
-
SRCS ${arg_SRCS}
|
| 249 |
-
ARCH "compute_${_STRIPPED_ARCH}"
|
| 250 |
-
CODE "sm_${_STRIPPED_ARCH}")
|
| 251 |
-
set_gencode_flag_for_srcs(
|
| 252 |
-
SRCS ${arg_SRCS}
|
| 253 |
-
ARCH "compute_${_STRIPPED_ARCH}"
|
| 254 |
-
CODE "compute_${_STRIPPED_ARCH}")
|
| 255 |
-
else()
|
| 256 |
-
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
| 257 |
-
set_gencode_flag_for_srcs(
|
| 258 |
-
SRCS ${arg_SRCS}
|
| 259 |
-
ARCH "compute_${_STRIPPED_ARCH}"
|
| 260 |
-
CODE "sm_${_STRIPPED_ARCH}")
|
| 261 |
-
endif()
|
| 262 |
-
endforeach()
|
| 263 |
-
|
| 264 |
-
if (${arg_BUILD_PTX_FOR_ARCH})
|
| 265 |
-
list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
| 266 |
-
list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
|
| 267 |
-
if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
|
| 268 |
-
string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
|
| 269 |
-
set_gencode_flag_for_srcs(
|
| 270 |
-
SRCS ${arg_SRCS}
|
| 271 |
-
ARCH "compute_${_PTX_ARCH}"
|
| 272 |
-
CODE "compute_${_PTX_ARCH}")
|
| 273 |
-
endif()
|
| 274 |
-
endif()
|
| 275 |
-
endmacro()
|
| 276 |
-
|
| 277 |
-
#
|
| 278 |
-
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
| 279 |
-
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
| 280 |
-
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
| 281 |
-
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
| 282 |
-
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
| 283 |
-
# architecture in `SRC_CUDA_ARCHS`.
|
| 284 |
-
# The loose intersection is defined as:
|
| 285 |
-
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
| 286 |
-
# where `<=` is the version comparison operator.
|
| 287 |
-
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
|
| 288 |
-
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
| 289 |
-
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
| 290 |
-
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
| 291 |
-
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
| 292 |
-
# The result is stored in `OUT_CUDA_ARCHS`.
|
| 293 |
-
#
|
| 294 |
-
# Example:
|
| 295 |
-
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
|
| 296 |
-
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
|
| 297 |
-
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
| 298 |
-
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
| 299 |
-
#
|
| 300 |
-
# Example With PTX:
|
| 301 |
-
# SRC_CUDA_ARCHS="8.0+PTX"
|
| 302 |
-
# TGT_CUDA_ARCHS="9.0"
|
| 303 |
-
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
| 304 |
-
# OUT_CUDA_ARCHS="8.0+PTX"
|
| 305 |
-
#
|
| 306 |
-
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
| 307 |
-
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
| 308 |
-
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
| 309 |
-
|
| 310 |
-
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
| 311 |
-
set(_PTX_ARCHS)
|
| 312 |
-
foreach(_arch ${_SRC_CUDA_ARCHS})
|
| 313 |
-
if(_arch MATCHES "\\+PTX$")
|
| 314 |
-
string(REPLACE "+PTX" "" _base "${_arch}")
|
| 315 |
-
list(APPEND _PTX_ARCHS "${_base}")
|
| 316 |
-
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
| 317 |
-
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
| 318 |
-
endif()
|
| 319 |
-
endforeach()
|
| 320 |
-
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
| 321 |
-
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
| 322 |
-
|
| 323 |
-
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
| 324 |
-
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
| 325 |
-
set(_CUDA_ARCHS)
|
| 326 |
-
foreach(_arch ${_SRC_CUDA_ARCHS})
|
| 327 |
-
if(_arch MATCHES "\\a$")
|
| 328 |
-
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
| 329 |
-
string(REPLACE "a" "" _base "${_arch}")
|
| 330 |
-
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
| 331 |
-
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
| 332 |
-
list(APPEND _CUDA_ARCHS "${_arch}")
|
| 333 |
-
endif()
|
| 334 |
-
endif()
|
| 335 |
-
endforeach()
|
| 336 |
-
|
| 337 |
-
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
| 338 |
-
|
| 339 |
-
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
| 340 |
-
# is less or equal to ARCH (but has the same major version since SASS binary
|
| 341 |
-
# compatibility is only forward compatible within the same major version).
|
| 342 |
-
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
| 343 |
-
set(_TMP_ARCH)
|
| 344 |
-
# Extract the major version of the target arch
|
| 345 |
-
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
| 346 |
-
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
| 347 |
-
# Extract the major version of the source arch
|
| 348 |
-
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
| 349 |
-
# Check version-less-or-equal, and allow PTX arches to match across majors
|
| 350 |
-
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
| 351 |
-
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
| 352 |
-
set(_TMP_ARCH "${_SRC_ARCH}")
|
| 353 |
-
endif()
|
| 354 |
-
else()
|
| 355 |
-
# If we hit a version greater than the target, we can break
|
| 356 |
-
break()
|
| 357 |
-
endif()
|
| 358 |
-
endforeach()
|
| 359 |
-
|
| 360 |
-
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
| 361 |
-
if (_TMP_ARCH)
|
| 362 |
-
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
| 363 |
-
endif()
|
| 364 |
-
endforeach()
|
| 365 |
-
|
| 366 |
-
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
| 367 |
-
|
| 368 |
-
# reapply +PTX suffix to architectures that requested PTX
|
| 369 |
-
set(_FINAL_ARCHS)
|
| 370 |
-
foreach(_arch ${_CUDA_ARCHS})
|
| 371 |
-
if(_arch IN_LIST _PTX_ARCHS)
|
| 372 |
-
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
| 373 |
-
else()
|
| 374 |
-
list(APPEND _FINAL_ARCHS "${_arch}")
|
| 375 |
-
endif()
|
| 376 |
-
endforeach()
|
| 377 |
-
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
| 378 |
-
|
| 379 |
-
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
| 380 |
-
endfunction()
|
| 381 |
-
|
| 382 |
-
#
|
| 383 |
-
# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
|
| 384 |
-
# `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
|
| 385 |
-
# The loose intersection is defined as:
|
| 386 |
-
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
| 387 |
-
# where `<=` is the version comparison operator.
|
| 388 |
-
# In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
|
| 389 |
-
# in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
|
| 390 |
-
# The result is stored in `OUT_ROCM_ARCHS`.
|
| 391 |
-
#
|
| 392 |
-
# Example:
|
| 393 |
-
# SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
|
| 394 |
-
# TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
|
| 395 |
-
# hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
|
| 396 |
-
# OUT_ROCM_ARCHS="gfx906;gfx908"
|
| 397 |
-
#
|
| 398 |
-
function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
|
| 399 |
-
list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
|
| 400 |
-
|
| 401 |
-
# ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
|
| 402 |
-
# and x is a letter. We can sort them by string comparison which works for this format.
|
| 403 |
-
list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
|
| 404 |
-
|
| 405 |
-
set(_ROCM_ARCHS)
|
| 406 |
-
|
| 407 |
-
# Find the intersection of supported architectures
|
| 408 |
-
foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
|
| 409 |
-
if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
|
| 410 |
-
list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
|
| 411 |
-
endif()
|
| 412 |
-
endforeach()
|
| 413 |
-
|
| 414 |
-
list(REMOVE_DUPLICATES _ROCM_ARCHS)
|
| 415 |
-
set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
|
| 416 |
-
endfunction()
|
| 417 |
-
|
| 418 |
-
#
|
| 419 |
-
# Override the GPU architectures detected by cmake/torch and filter them by
|
| 420 |
-
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
| 421 |
-
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
|
| 422 |
-
# the architectures on a per file basis.
|
| 423 |
-
#
|
| 424 |
-
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
| 425 |
-
#
|
| 426 |
-
macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
| 427 |
-
set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
|
| 428 |
-
message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
|
| 429 |
-
|
| 430 |
-
if (${GPU_LANG} STREQUAL "HIP")
|
| 431 |
-
#
|
| 432 |
-
# `GPU_ARCHES` controls the `--offload-arch` flags.
|
| 433 |
-
#
|
| 434 |
-
# If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
|
| 435 |
-
# if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
|
| 436 |
-
# "rocm_agent_enumerator" in "enable_language(HIP)"
|
| 437 |
-
# (in file Modules/CMakeDetermineHIPCompiler.cmake)
|
| 438 |
-
#
|
| 439 |
-
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
|
| 440 |
-
set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
|
| 441 |
-
else()
|
| 442 |
-
set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
|
| 443 |
-
endif()
|
| 444 |
-
#
|
| 445 |
-
# Find the intersection of the supported + detected architectures to
|
| 446 |
-
# set the module architecture flags.
|
| 447 |
-
#
|
| 448 |
-
set(${GPU_ARCHES})
|
| 449 |
-
foreach (_ARCH ${HIP_ARCHITECTURES})
|
| 450 |
-
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
| 451 |
-
list(APPEND ${GPU_ARCHES} ${_ARCH})
|
| 452 |
-
endif()
|
| 453 |
-
endforeach()
|
| 454 |
-
|
| 455 |
-
if(NOT ${GPU_ARCHES})
|
| 456 |
-
message(FATAL_ERROR
|
| 457 |
-
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
|
| 458 |
-
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
| 459 |
-
endif()
|
| 460 |
-
endif()
|
| 461 |
-
endmacro()
|
| 462 |
-
|
| 463 |
-
#
|
| 464 |
-
# Define a target named `GPU_MOD_NAME` for a single extension. The
|
| 465 |
-
# arguments are:
|
| 466 |
-
#
|
| 467 |
-
# DESTINATION <dest> - Module destination directory.
|
| 468 |
-
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
|
| 469 |
-
# etc.
|
| 470 |
-
# SOURCES <sources> - List of source files relative to CMakeLists.txt
|
| 471 |
-
# directory.
|
| 472 |
-
#
|
| 473 |
-
# Optional arguments:
|
| 474 |
-
#
|
| 475 |
-
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
|
| 476 |
-
# format.
|
| 477 |
-
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
|
| 478 |
-
# and `CMAKE_HIP_ARCHITECTURES` for more info.
|
| 479 |
-
# ARCHITECTURES will use cmake's defaults if
|
| 480 |
-
# not provided.
|
| 481 |
-
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
|
| 482 |
-
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
| 483 |
-
# LIBRARIES <libraries> - Extra link libraries.
|
| 484 |
-
# WITH_SOABI - Generate library with python SOABI suffix name.
|
| 485 |
-
# USE_SABI <version> - Use python stable api <version>
|
| 486 |
-
#
|
| 487 |
-
# Note: optimization level/debug info is set via cmake build type.
|
| 488 |
-
#
|
| 489 |
-
function (define_gpu_extension_target GPU_MOD_NAME)
|
| 490 |
-
cmake_parse_arguments(PARSE_ARGV 1
|
| 491 |
-
GPU
|
| 492 |
-
"WITH_SOABI"
|
| 493 |
-
"DESTINATION;LANGUAGE;USE_SABI"
|
| 494 |
-
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
| 495 |
-
|
| 496 |
-
# Add hipify preprocessing step when building with HIP/ROCm.
|
| 497 |
-
if (GPU_LANGUAGE STREQUAL "HIP")
|
| 498 |
-
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
|
| 499 |
-
endif()
|
| 500 |
-
|
| 501 |
-
if (GPU_WITH_SOABI)
|
| 502 |
-
set(GPU_WITH_SOABI WITH_SOABI)
|
| 503 |
-
else()
|
| 504 |
-
set(GPU_WITH_SOABI)
|
| 505 |
-
endif()
|
| 506 |
-
|
| 507 |
-
if (GPU_USE_SABI)
|
| 508 |
-
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
| 509 |
-
else()
|
| 510 |
-
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
| 511 |
-
endif()
|
| 512 |
-
|
| 513 |
-
if (GPU_LANGUAGE STREQUAL "HIP")
|
| 514 |
-
# Make this target dependent on the hipify preprocessor step.
|
| 515 |
-
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
|
| 516 |
-
endif()
|
| 517 |
-
|
| 518 |
-
if (GPU_ARCHITECTURES)
|
| 519 |
-
set_target_properties(${GPU_MOD_NAME} PROPERTIES
|
| 520 |
-
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
|
| 521 |
-
endif()
|
| 522 |
-
|
| 523 |
-
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
|
| 524 |
-
|
| 525 |
-
target_compile_options(${GPU_MOD_NAME} PRIVATE
|
| 526 |
-
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
|
| 527 |
-
|
| 528 |
-
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
|
| 529 |
-
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
|
| 530 |
-
|
| 531 |
-
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
| 532 |
-
${GPU_INCLUDE_DIRECTORIES})
|
| 533 |
-
|
| 534 |
-
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
|
| 535 |
-
|
| 536 |
-
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
|
| 537 |
-
# dependencies that are not necessary and may not be installed.
|
| 538 |
-
if (GPU_LANGUAGE STREQUAL "CUDA")
|
| 539 |
-
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
|
| 540 |
-
else()
|
| 541 |
-
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
|
| 542 |
-
endif()
|
| 543 |
-
|
| 544 |
-
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
|
| 545 |
-
endfunction()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln.h
DELETED
|
@@ -1,281 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
|
| 3 |
-
#include <unordered_map>
|
| 4 |
-
#include <cuda_fp16.h>
|
| 5 |
-
#include <cuda_bf16.h>
|
| 6 |
-
|
| 7 |
-
#ifdef OLD_GENERATOR_PATH
|
| 8 |
-
#include <ATen/CUDAGeneratorImpl.h>
|
| 9 |
-
#else
|
| 10 |
-
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
| 11 |
-
#endif
|
| 12 |
-
|
| 13 |
-
namespace layer_norm {
|
| 14 |
-
|
| 15 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 16 |
-
|
| 17 |
-
template<typename Params>
|
| 18 |
-
struct LaunchParams{
|
| 19 |
-
|
| 20 |
-
size_t elts_per_thread;
|
| 21 |
-
size_t workspace_bytes;
|
| 22 |
-
size_t barrier_size;
|
| 23 |
-
|
| 24 |
-
cudaDeviceProp * props;
|
| 25 |
-
|
| 26 |
-
cudaStream_t stream;
|
| 27 |
-
|
| 28 |
-
Params params;
|
| 29 |
-
|
| 30 |
-
};
|
| 31 |
-
|
| 32 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 33 |
-
|
| 34 |
-
struct ParamsBase {
|
| 35 |
-
ParamsBase()
|
| 36 |
-
: ctas_per_col(0)
|
| 37 |
-
, rows(0)
|
| 38 |
-
, cols(0)
|
| 39 |
-
, x(nullptr)
|
| 40 |
-
, mu(nullptr)
|
| 41 |
-
, rs(nullptr)
|
| 42 |
-
, gamma(nullptr)
|
| 43 |
-
, gamma1(nullptr)
|
| 44 |
-
, rowscale(nullptr)
|
| 45 |
-
, colscale(nullptr)
|
| 46 |
-
, dropout_keep_p(1.f)
|
| 47 |
-
, dropout_scale(1.f)
|
| 48 |
-
, is_rms_norm(false)
|
| 49 |
-
, workspace(nullptr)
|
| 50 |
-
, barrier(nullptr)
|
| 51 |
-
{
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
|
| 55 |
-
int ctas_per_col;
|
| 56 |
-
|
| 57 |
-
// Input is interpreted as matrix. We normalize across columns.
|
| 58 |
-
int rows;
|
| 59 |
-
int cols;
|
| 60 |
-
|
| 61 |
-
// Common data pointers.
|
| 62 |
-
void *x0;
|
| 63 |
-
void *x1;
|
| 64 |
-
void *residual;
|
| 65 |
-
void *x;
|
| 66 |
-
void *dmask;
|
| 67 |
-
void *dmask1;
|
| 68 |
-
void *mu;
|
| 69 |
-
void *rs;
|
| 70 |
-
void *gamma;
|
| 71 |
-
void *gamma1;
|
| 72 |
-
void *rowscale;
|
| 73 |
-
void *colscale;
|
| 74 |
-
void *x0_subset;
|
| 75 |
-
void *z_subset;
|
| 76 |
-
|
| 77 |
-
float inverse_cols;
|
| 78 |
-
|
| 79 |
-
float dropout_keep_p;
|
| 80 |
-
float dropout_scale;
|
| 81 |
-
float rowscale_const;
|
| 82 |
-
|
| 83 |
-
bool is_rms_norm;
|
| 84 |
-
|
| 85 |
-
// Multi-CTA workspace in gmem.
|
| 86 |
-
void *workspace;
|
| 87 |
-
|
| 88 |
-
// Multi-CTA sync barriers in gmem.
|
| 89 |
-
int *barrier;
|
| 90 |
-
|
| 91 |
-
};
|
| 92 |
-
|
| 93 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
-
|
| 95 |
-
struct FwdParams : public ParamsBase {
|
| 96 |
-
FwdParams()
|
| 97 |
-
: ParamsBase()
|
| 98 |
-
, z(nullptr)
|
| 99 |
-
, z1(nullptr)
|
| 100 |
-
, beta(nullptr)
|
| 101 |
-
, beta1(nullptr)
|
| 102 |
-
, epsilon(0.f)
|
| 103 |
-
{
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
// Output of LN FWD.
|
| 107 |
-
void *z;
|
| 108 |
-
void *z1;
|
| 109 |
-
void *beta;
|
| 110 |
-
void *beta1;
|
| 111 |
-
float epsilon;
|
| 112 |
-
|
| 113 |
-
// Random state.
|
| 114 |
-
at::PhiloxCudaState philox_args;
|
| 115 |
-
};
|
| 116 |
-
|
| 117 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 118 |
-
|
| 119 |
-
struct BwdParams : public ParamsBase {
|
| 120 |
-
BwdParams()
|
| 121 |
-
: ParamsBase()
|
| 122 |
-
, dz(nullptr)
|
| 123 |
-
, dz1(nullptr)
|
| 124 |
-
, dx(nullptr)
|
| 125 |
-
, dbeta_part(nullptr)
|
| 126 |
-
, dgamma_part(nullptr)
|
| 127 |
-
, dbeta1_part(nullptr)
|
| 128 |
-
, dgamma1_part(nullptr)
|
| 129 |
-
, dcolscale_part(nullptr)
|
| 130 |
-
, dx0(nullptr)
|
| 131 |
-
, dx1(nullptr)
|
| 132 |
-
, dresidual(nullptr)
|
| 133 |
-
, dbeta(nullptr)
|
| 134 |
-
, dgamma(nullptr)
|
| 135 |
-
, dbeta1(nullptr)
|
| 136 |
-
, dgamma1(nullptr)
|
| 137 |
-
, dcolscale(nullptr)
|
| 138 |
-
{
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
// Input: gradient wrt. LN FWD output.
|
| 142 |
-
void *dz;
|
| 143 |
-
void *dz1;
|
| 144 |
-
// Input: gradient wrt residual.
|
| 145 |
-
void *dx;
|
| 146 |
-
|
| 147 |
-
// Workspace for Wgrad pre-reduction.
|
| 148 |
-
void *dbeta_part;
|
| 149 |
-
void *dgamma_part;
|
| 150 |
-
void *dbeta1_part;
|
| 151 |
-
void *dgamma1_part;
|
| 152 |
-
void *dcolscale_part;
|
| 153 |
-
|
| 154 |
-
// Output: Dgrad.
|
| 155 |
-
void *dx0;
|
| 156 |
-
void *dx1;
|
| 157 |
-
void *dresidual;
|
| 158 |
-
// Output: Wgrad.
|
| 159 |
-
void *dbeta;
|
| 160 |
-
void *dgamma;
|
| 161 |
-
void *dbeta1;
|
| 162 |
-
void *dgamma1;
|
| 163 |
-
void *dcolscale;
|
| 164 |
-
|
| 165 |
-
};
|
| 166 |
-
|
| 167 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 168 |
-
|
| 169 |
-
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
|
| 170 |
-
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
|
| 171 |
-
using FunctionKey = uint64_t;
|
| 172 |
-
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
|
| 173 |
-
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
|
| 174 |
-
|
| 175 |
-
extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
|
| 176 |
-
extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
|
| 177 |
-
|
| 178 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 179 |
-
|
| 180 |
-
using fp32 = float;
|
| 181 |
-
using fp16 = half;
|
| 182 |
-
using bf16 = nv_bfloat16;
|
| 183 |
-
|
| 184 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 185 |
-
|
| 186 |
-
template<typename T>
|
| 187 |
-
struct TypeId{};
|
| 188 |
-
|
| 189 |
-
template<>
|
| 190 |
-
struct TypeId<fp16>{
|
| 191 |
-
constexpr static uint32_t Value = 0;
|
| 192 |
-
};
|
| 193 |
-
|
| 194 |
-
template<>
|
| 195 |
-
struct TypeId<bf16>{
|
| 196 |
-
constexpr static uint32_t Value = 1;
|
| 197 |
-
};
|
| 198 |
-
|
| 199 |
-
template<>
|
| 200 |
-
struct TypeId<fp32>{
|
| 201 |
-
constexpr static uint32_t Value = 2;
|
| 202 |
-
};
|
| 203 |
-
|
| 204 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 205 |
-
|
| 206 |
-
template<typename T, int S>
|
| 207 |
-
struct Type2Key{
|
| 208 |
-
constexpr static uint32_t Value = TypeId<T>::Value << S;
|
| 209 |
-
};
|
| 210 |
-
|
| 211 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 212 |
-
|
| 213 |
-
template<typename T>
|
| 214 |
-
struct WeightType2Key : public Type2Key<T, 0>{};
|
| 215 |
-
|
| 216 |
-
template<typename T>
|
| 217 |
-
struct InputType2Key : public Type2Key<T, 2>{};
|
| 218 |
-
|
| 219 |
-
template<typename T>
|
| 220 |
-
struct ResidualType2Key : public Type2Key<T, 4>{};
|
| 221 |
-
|
| 222 |
-
template<typename T>
|
| 223 |
-
struct OutputType2Key : public Type2Key<T, 6>{};
|
| 224 |
-
|
| 225 |
-
template<typename T>
|
| 226 |
-
struct ComputeType2Key : public Type2Key<T, 8>{};
|
| 227 |
-
|
| 228 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 229 |
-
|
| 230 |
-
template<typename W, typename I, typename R, typename O, typename C>
|
| 231 |
-
struct Types2Key{
|
| 232 |
-
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
|
| 233 |
-
constexpr static inline uint64_t get(const uint64_t hidden_size){
|
| 234 |
-
constexpr uint64_t type_key = Value;
|
| 235 |
-
return (type_key << 32) | hidden_size;
|
| 236 |
-
}
|
| 237 |
-
};
|
| 238 |
-
|
| 239 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 240 |
-
|
| 241 |
-
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
| 242 |
-
struct FwdRegistrar{
|
| 243 |
-
FwdRegistrar(FwdFunction f){
|
| 244 |
-
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
| 245 |
-
FWD_FUNCS.insert({ key, f });
|
| 246 |
-
}
|
| 247 |
-
};
|
| 248 |
-
|
| 249 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 250 |
-
|
| 251 |
-
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
| 252 |
-
struct BwdRegistrar{
|
| 253 |
-
BwdRegistrar(BwdFunction f){
|
| 254 |
-
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
| 255 |
-
BWD_FUNCS.insert({ key, f });
|
| 256 |
-
}
|
| 257 |
-
};
|
| 258 |
-
|
| 259 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 260 |
-
|
| 261 |
-
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
| 262 |
-
struct FwdParallelRegistrar{
|
| 263 |
-
FwdParallelRegistrar(FwdFunction f){
|
| 264 |
-
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
| 265 |
-
PARALLEL_FWD_FUNCS.insert({ key, f });
|
| 266 |
-
}
|
| 267 |
-
};
|
| 268 |
-
|
| 269 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 270 |
-
|
| 271 |
-
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
| 272 |
-
struct BwdParallelRegistrar{
|
| 273 |
-
BwdParallelRegistrar(BwdFunction f){
|
| 274 |
-
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
| 275 |
-
PARALLEL_BWD_FUNCS.insert({ key, f });
|
| 276 |
-
}
|
| 277 |
-
};
|
| 278 |
-
|
| 279 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 280 |
-
|
| 281 |
-
} // namespace layer_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_api.cpp
DELETED
|
@@ -1,828 +0,0 @@
|
|
| 1 |
-
#include <torch/torch.h>
|
| 2 |
-
#include "ATen/cuda/CUDAContext.h"
|
| 3 |
-
#include <c10/cuda/CUDAGuard.h>
|
| 4 |
-
|
| 5 |
-
#include "ln.h"
|
| 6 |
-
|
| 7 |
-
/*
|
| 8 |
-
|
| 9 |
-
Supported Type combinations:
|
| 10 |
-
|
| 11 |
-
input residual compute weights output
|
| 12 |
-
============================================
|
| 13 |
-
fp32 fp32 fp32 fp32 fp32
|
| 14 |
-
fp16 fp32 fp32 fp32 fp16
|
| 15 |
-
fp16 fp16 fp32 fp32 fp16
|
| 16 |
-
bf16 fp32 fp32 fp32 bf16
|
| 17 |
-
bf16 bf16 fp32 fp32 bf16
|
| 18 |
-
fp16 fp16 fp32 fp16 fp16
|
| 19 |
-
bf16 bf16 fp32 bf16 bf16
|
| 20 |
-
|
| 21 |
-
Remarks:
|
| 22 |
-
Output type = Input type
|
| 23 |
-
Compute always in FP32
|
| 24 |
-
|
| 25 |
-
*/
|
| 26 |
-
|
| 27 |
-
namespace layer_norm {
|
| 28 |
-
|
| 29 |
-
// Create registries and provide runtime versions of config hash functions.
|
| 30 |
-
|
| 31 |
-
FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
|
| 32 |
-
BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
|
| 33 |
-
|
| 34 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 35 |
-
|
| 36 |
-
uint32_t get_type_id(torch::Dtype dtype){
|
| 37 |
-
if( dtype == torch::kFloat16 ) {
|
| 38 |
-
return TypeId<fp16>::Value;
|
| 39 |
-
} else if( dtype == torch::kBFloat16 ) {
|
| 40 |
-
return TypeId<bf16>::Value;
|
| 41 |
-
} else if( dtype == torch::kFloat32 ) {
|
| 42 |
-
return TypeId<fp32>::Value;
|
| 43 |
-
} else {
|
| 44 |
-
TORCH_CHECK(false, "Type not supported: ", dtype);
|
| 45 |
-
}
|
| 46 |
-
}
|
| 47 |
-
|
| 48 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
-
|
| 50 |
-
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
|
| 51 |
-
using namespace layer_norm;
|
| 52 |
-
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
|
| 53 |
-
uint64_t launcher_key = (type_key << 32) | hidden_size;
|
| 54 |
-
return launcher_key;
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
} // namespace layer_norm
|
| 58 |
-
|
| 59 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
-
|
| 61 |
-
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 62 |
-
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 63 |
-
if( iter != layer_norm::FWD_FUNCS.end() ) {
|
| 64 |
-
return iter->second;
|
| 65 |
-
} else {
|
| 66 |
-
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 67 |
-
}
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 71 |
-
|
| 72 |
-
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 73 |
-
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 74 |
-
if( iter != layer_norm::BWD_FUNCS.end() ) {
|
| 75 |
-
return iter->second;
|
| 76 |
-
} else {
|
| 77 |
-
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 78 |
-
}
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 82 |
-
|
| 83 |
-
layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 84 |
-
auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 85 |
-
if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
|
| 86 |
-
return iter->second;
|
| 87 |
-
} else {
|
| 88 |
-
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 89 |
-
}
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 93 |
-
|
| 94 |
-
layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 95 |
-
auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 96 |
-
if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
|
| 97 |
-
return iter->second;
|
| 98 |
-
} else {
|
| 99 |
-
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 100 |
-
}
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 104 |
-
|
| 105 |
-
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
|
| 106 |
-
c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
|
| 107 |
-
const at::Tensor &gamma, // hidden_size
|
| 108 |
-
c10::optional<const at::Tensor> &beta_, // hidden_size
|
| 109 |
-
c10::optional<const at::Tensor> &rowscale_, // BxS
|
| 110 |
-
c10::optional<const at::Tensor> &colscale_, // hidden_size
|
| 111 |
-
c10::optional<const at::Tensor> &x0_subset_, // BxS
|
| 112 |
-
c10::optional<const at::Tensor> &z_subset_, // BxS
|
| 113 |
-
const float dropout_p,
|
| 114 |
-
const float epsilon,
|
| 115 |
-
const float rowscale_const,
|
| 116 |
-
const int64_t z_numrows,
|
| 117 |
-
c10::optional<at::Generator> gen_,
|
| 118 |
-
bool residual_in_fp32=false,
|
| 119 |
-
bool is_rms_norm=false
|
| 120 |
-
) {
|
| 121 |
-
auto itype = x0.scalar_type();
|
| 122 |
-
auto rtype = residual_.has_value()
|
| 123 |
-
? residual_.value().scalar_type()
|
| 124 |
-
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
|
| 125 |
-
auto wtype = gamma.scalar_type();
|
| 126 |
-
auto otype = itype;
|
| 127 |
-
auto ctype = torch::kFloat32;
|
| 128 |
-
auto mtype = torch::kUInt8;
|
| 129 |
-
|
| 130 |
-
TORCH_CHECK(x0.is_cuda());
|
| 131 |
-
TORCH_CHECK(gamma.is_cuda());
|
| 132 |
-
|
| 133 |
-
TORCH_CHECK(x0.is_contiguous());
|
| 134 |
-
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
|
| 135 |
-
// Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
|
| 136 |
-
// blah is then deallocated.
|
| 137 |
-
std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
|
| 138 |
-
auto sizes = c10::IntArrayRef(sizes_vec);
|
| 139 |
-
TORCH_CHECK(x0.dim() == 2);
|
| 140 |
-
TORCH_CHECK(sizes.size() == 2);
|
| 141 |
-
|
| 142 |
-
const int rows = sizes[0];
|
| 143 |
-
const int cols = sizes[1];
|
| 144 |
-
auto hidden_size = gamma.numel();
|
| 145 |
-
TORCH_CHECK(hidden_size == cols);
|
| 146 |
-
|
| 147 |
-
if (beta_.has_value()) {
|
| 148 |
-
auto beta = beta_.value();
|
| 149 |
-
TORCH_CHECK(beta.dtype() == wtype);
|
| 150 |
-
TORCH_CHECK(beta.is_cuda());
|
| 151 |
-
TORCH_CHECK(beta.is_contiguous());
|
| 152 |
-
TORCH_CHECK(beta.sizes() == gamma.sizes());
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
if (residual_.has_value()) {
|
| 156 |
-
auto residual = residual_.value();
|
| 157 |
-
TORCH_CHECK(residual.is_cuda());
|
| 158 |
-
TORCH_CHECK(residual.is_contiguous());
|
| 159 |
-
TORCH_CHECK(residual.sizes() == sizes);
|
| 160 |
-
}
|
| 161 |
-
|
| 162 |
-
if (rowscale_.has_value()) {
|
| 163 |
-
auto rowscale = rowscale_.value();
|
| 164 |
-
TORCH_CHECK(rowscale.is_cuda());
|
| 165 |
-
TORCH_CHECK(rowscale.is_contiguous());
|
| 166 |
-
TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
|
| 167 |
-
TORCH_CHECK(rowscale.dtype() == itype);
|
| 168 |
-
}
|
| 169 |
-
|
| 170 |
-
if (colscale_.has_value()) {
|
| 171 |
-
auto colscale = colscale_.value();
|
| 172 |
-
TORCH_CHECK(colscale.is_cuda());
|
| 173 |
-
TORCH_CHECK(colscale.is_contiguous());
|
| 174 |
-
TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
|
| 175 |
-
TORCH_CHECK(colscale.dtype() == wtype);
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
if (x0_subset_.has_value()) {
|
| 179 |
-
auto x0_subset = x0_subset_.value();
|
| 180 |
-
TORCH_CHECK(x0_subset.is_cuda());
|
| 181 |
-
TORCH_CHECK(x0_subset.is_contiguous());
|
| 182 |
-
TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
|
| 183 |
-
TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
|
| 184 |
-
|
| 185 |
-
TORCH_CHECK(z_subset_.has_value());
|
| 186 |
-
auto z_subset = z_subset_.value();
|
| 187 |
-
TORCH_CHECK(z_subset.is_cuda());
|
| 188 |
-
TORCH_CHECK(z_subset.is_contiguous());
|
| 189 |
-
TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
|
| 190 |
-
TORCH_CHECK(z_subset.dtype() == torch::kInt32);
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
|
| 194 |
-
TORCH_CHECK(epsilon >= 0.f);
|
| 195 |
-
|
| 196 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 197 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 198 |
-
at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
|
| 199 |
-
|
| 200 |
-
auto opts = x0.options();
|
| 201 |
-
|
| 202 |
-
bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
|
| 203 |
-
at::Tensor x;
|
| 204 |
-
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
|
| 205 |
-
at::Tensor dmask;
|
| 206 |
-
if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
|
| 207 |
-
auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
|
| 208 |
-
|
| 209 |
-
auto mu = torch::empty({ rows }, opts.dtype(ctype));
|
| 210 |
-
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
|
| 211 |
-
|
| 212 |
-
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
|
| 213 |
-
|
| 214 |
-
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
| 215 |
-
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
| 216 |
-
TORCH_CHECK(dropout_p < 1.f);
|
| 217 |
-
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
| 218 |
-
launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
|
| 219 |
-
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
| 220 |
-
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
| 221 |
-
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
|
| 222 |
-
launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
|
| 223 |
-
|
| 224 |
-
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
| 225 |
-
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
| 226 |
-
|
| 227 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 228 |
-
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
| 229 |
-
// Request the kernel launcher.
|
| 230 |
-
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
| 231 |
-
|
| 232 |
-
// Set the kernel runtime parameters.
|
| 233 |
-
layer_norm::FwdParams ¶ms = launch_params.params;
|
| 234 |
-
params.rows = rows;
|
| 235 |
-
params.cols = cols;
|
| 236 |
-
params.x0 = x0.data_ptr();
|
| 237 |
-
params.x = save_x ? x.data_ptr() : nullptr;
|
| 238 |
-
params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
|
| 239 |
-
params.mu = mu.data_ptr();
|
| 240 |
-
params.rs = rsigma.data_ptr();
|
| 241 |
-
params.gamma = gamma.data_ptr();
|
| 242 |
-
params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
|
| 243 |
-
params.z = z.data_ptr();
|
| 244 |
-
params.epsilon = epsilon;
|
| 245 |
-
params.dropout_scale = 1.f / (1.f - dropout_p);
|
| 246 |
-
params.inverse_cols = 1.f / float(params.cols);
|
| 247 |
-
params.rowscale_const = rowscale_const;
|
| 248 |
-
params.is_rms_norm = is_rms_norm;
|
| 249 |
-
|
| 250 |
-
// Query the kernel-specific launch parameters.
|
| 251 |
-
launcher(launch_params, true);
|
| 252 |
-
|
| 253 |
-
at::Tensor workspace, barrier;
|
| 254 |
-
|
| 255 |
-
if (dropout_p > 0.f) {
|
| 256 |
-
// number of times random will be generated per thread, to offset philox counter in thc random
|
| 257 |
-
// state
|
| 258 |
-
int64_t counter_offset = launch_params.elts_per_thread;
|
| 259 |
-
|
| 260 |
-
// See Note [Acquire lock when using random generators]
|
| 261 |
-
{
|
| 262 |
-
std::lock_guard<std::mutex> lock(gen->mutex_);
|
| 263 |
-
params.philox_args = gen->philox_cuda_state(counter_offset);
|
| 264 |
-
}
|
| 265 |
-
}
|
| 266 |
-
|
| 267 |
-
if( launch_params.barrier_size > 0 ) {
|
| 268 |
-
auto options = x0.options();
|
| 269 |
-
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
|
| 270 |
-
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
|
| 271 |
-
params.workspace = workspace.data_ptr();
|
| 272 |
-
params.barrier = barrier.data_ptr<int>();
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
// Launch the kernel.
|
| 276 |
-
launcher(launch_params, false);
|
| 277 |
-
|
| 278 |
-
return { z, x, dmask, mu, rsigma };
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 282 |
-
|
| 283 |
-
std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
|
| 284 |
-
c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
|
| 285 |
-
const at::Tensor &x, // BxSxhidden_size
|
| 286 |
-
c10::optional<const at::Tensor> &x0_, // BxSxhidden_size
|
| 287 |
-
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
|
| 288 |
-
const at::Tensor &mu, // BxS, FP32!
|
| 289 |
-
const at::Tensor &rsigma, // BxS, FP32!
|
| 290 |
-
const at::Tensor &gamma, // hidden_size
|
| 291 |
-
c10::optional<const at::Tensor> &rowscale_, // BxS
|
| 292 |
-
c10::optional<const at::Tensor> &colscale_, // hidden_size
|
| 293 |
-
c10::optional<const at::Tensor> &x0_subset_, // BxS
|
| 294 |
-
c10::optional<const at::Tensor> &z_subset_, // BxS
|
| 295 |
-
const float dropout_p,
|
| 296 |
-
const float rowscale_const,
|
| 297 |
-
const int64_t x0_numrows,
|
| 298 |
-
const bool has_residual,
|
| 299 |
-
bool is_rms_norm=false
|
| 300 |
-
) {
|
| 301 |
-
|
| 302 |
-
auto itype = dz.scalar_type();
|
| 303 |
-
auto rtype = x.scalar_type();
|
| 304 |
-
auto wtype = gamma.scalar_type();
|
| 305 |
-
auto otype = itype;
|
| 306 |
-
auto ctype = torch::kFloat32;
|
| 307 |
-
auto mtype = torch::kUInt8;
|
| 308 |
-
|
| 309 |
-
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
|
| 310 |
-
|
| 311 |
-
TORCH_CHECK(dz.dtype() == otype);
|
| 312 |
-
TORCH_CHECK(mu.dtype() == ctype);
|
| 313 |
-
TORCH_CHECK(rsigma.dtype() == ctype);
|
| 314 |
-
|
| 315 |
-
TORCH_CHECK(x.is_cuda());
|
| 316 |
-
TORCH_CHECK(dz.is_cuda());
|
| 317 |
-
TORCH_CHECK(mu.is_cuda());
|
| 318 |
-
TORCH_CHECK(rsigma.is_cuda());
|
| 319 |
-
TORCH_CHECK(gamma.is_cuda());
|
| 320 |
-
|
| 321 |
-
TORCH_CHECK(x.is_contiguous());
|
| 322 |
-
TORCH_CHECK(dz.is_contiguous());
|
| 323 |
-
|
| 324 |
-
auto sizes = x.sizes();
|
| 325 |
-
TORCH_CHECK(sizes.size() == 2);
|
| 326 |
-
auto rows = sizes[0];
|
| 327 |
-
auto cols = sizes[1];
|
| 328 |
-
TORCH_CHECK(dz.dim() == 2);
|
| 329 |
-
TORCH_CHECK(dz.size(1) == cols);
|
| 330 |
-
auto hidden_size = gamma.numel();
|
| 331 |
-
TORCH_CHECK(hidden_size == cols);
|
| 332 |
-
|
| 333 |
-
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
|
| 334 |
-
// Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
|
| 335 |
-
// blah is then deallocated.
|
| 336 |
-
std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
|
| 337 |
-
auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
|
| 338 |
-
|
| 339 |
-
if (dx_.has_value()) {
|
| 340 |
-
auto dx = dx_.value();
|
| 341 |
-
TORCH_CHECK(dx.dtype() == rtype);
|
| 342 |
-
TORCH_CHECK(dx.is_cuda());
|
| 343 |
-
TORCH_CHECK(dx.is_contiguous());
|
| 344 |
-
TORCH_CHECK(dx.sizes() == sizes);
|
| 345 |
-
}
|
| 346 |
-
|
| 347 |
-
if (dmask_.has_value()) {
|
| 348 |
-
auto dmask = dmask_.value();
|
| 349 |
-
TORCH_CHECK(dmask.dtype() == mtype);
|
| 350 |
-
TORCH_CHECK(dmask.is_cuda());
|
| 351 |
-
TORCH_CHECK(dmask.is_contiguous());
|
| 352 |
-
TORCH_CHECK(dmask.sizes() == x0_sizes);
|
| 353 |
-
}
|
| 354 |
-
|
| 355 |
-
if (rowscale_.has_value()) {
|
| 356 |
-
auto rowscale = rowscale_.value();
|
| 357 |
-
TORCH_CHECK(rowscale.is_cuda());
|
| 358 |
-
TORCH_CHECK(rowscale.is_contiguous());
|
| 359 |
-
TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
|
| 360 |
-
TORCH_CHECK(rowscale.dtype() == itype);
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
if (colscale_.has_value()) {
|
| 364 |
-
auto colscale = colscale_.value();
|
| 365 |
-
TORCH_CHECK(colscale.is_cuda());
|
| 366 |
-
TORCH_CHECK(colscale.is_contiguous());
|
| 367 |
-
TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
|
| 368 |
-
TORCH_CHECK(colscale.dtype() == wtype);
|
| 369 |
-
|
| 370 |
-
TORCH_CHECK(x0_.has_value());
|
| 371 |
-
auto x0 = x0_.value();
|
| 372 |
-
TORCH_CHECK(x0.is_cuda());
|
| 373 |
-
TORCH_CHECK(x0.is_contiguous());
|
| 374 |
-
TORCH_CHECK(x0.sizes() == x0_sizes);
|
| 375 |
-
TORCH_CHECK(x0.dtype() == itype);
|
| 376 |
-
}
|
| 377 |
-
|
| 378 |
-
if (x0_subset_.has_value()) {
|
| 379 |
-
auto x0_subset = x0_subset_.value();
|
| 380 |
-
TORCH_CHECK(x0_subset.is_cuda());
|
| 381 |
-
TORCH_CHECK(x0_subset.is_contiguous());
|
| 382 |
-
TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
|
| 383 |
-
TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
|
| 384 |
-
|
| 385 |
-
TORCH_CHECK(z_subset_.has_value());
|
| 386 |
-
auto z_subset = z_subset_.value();
|
| 387 |
-
TORCH_CHECK(z_subset.is_cuda());
|
| 388 |
-
TORCH_CHECK(z_subset.is_contiguous());
|
| 389 |
-
TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
|
| 390 |
-
TORCH_CHECK(z_subset.dtype() == torch::kInt32);
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
|
| 394 |
-
|
| 395 |
-
TORCH_CHECK(mu.numel() == rows);
|
| 396 |
-
TORCH_CHECK(mu.sizes() == rsigma.sizes());
|
| 397 |
-
|
| 398 |
-
TORCH_CHECK(gamma.numel() == cols);
|
| 399 |
-
|
| 400 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 401 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 402 |
-
at::cuda::CUDAGuard device_guard{(char)dz.get_device()};
|
| 403 |
-
|
| 404 |
-
auto opts = x.options();
|
| 405 |
-
|
| 406 |
-
auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
|
| 407 |
-
at::Tensor dresidual;
|
| 408 |
-
if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
|
| 409 |
-
auto dgamma = torch::empty_like(gamma);
|
| 410 |
-
auto dbeta = torch::empty_like(gamma);
|
| 411 |
-
at::Tensor dcolscale;
|
| 412 |
-
if (colscale_.has_value()) {
|
| 413 |
-
dcolscale = torch::empty_like(colscale_.value());
|
| 414 |
-
}
|
| 415 |
-
|
| 416 |
-
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
| 417 |
-
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
| 418 |
-
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
| 419 |
-
TORCH_CHECK(dropout_p < 1.f);
|
| 420 |
-
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
| 421 |
-
launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
|
| 422 |
-
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
| 423 |
-
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
| 424 |
-
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
|
| 425 |
-
launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
|
| 426 |
-
|
| 427 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 428 |
-
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
| 429 |
-
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
| 430 |
-
|
| 431 |
-
launcher(launch_params, true);
|
| 432 |
-
|
| 433 |
-
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
| 434 |
-
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
| 435 |
-
at::Tensor dcolscale_part;
|
| 436 |
-
if (colscale_.has_value()) {
|
| 437 |
-
dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
| 438 |
-
}
|
| 439 |
-
at::Tensor workspace, barrier;
|
| 440 |
-
|
| 441 |
-
layer_norm::BwdParams ¶ms = launch_params.params;
|
| 442 |
-
params.rows = rows;
|
| 443 |
-
params.cols = cols;
|
| 444 |
-
params.x = x.data_ptr();
|
| 445 |
-
params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
|
| 446 |
-
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
|
| 447 |
-
params.mu = mu.data_ptr();
|
| 448 |
-
params.rs = rsigma.data_ptr();
|
| 449 |
-
params.gamma = gamma.data_ptr();
|
| 450 |
-
params.dz = dz.data_ptr();
|
| 451 |
-
params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
|
| 452 |
-
params.dx0 = dx0.data_ptr();
|
| 453 |
-
params.dbeta = dbeta.data_ptr();
|
| 454 |
-
params.dgamma = dgamma.data_ptr();
|
| 455 |
-
params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
|
| 456 |
-
params.dbeta_part = dbeta_part.data_ptr();
|
| 457 |
-
params.dgamma_part = dgamma_part.data_ptr();
|
| 458 |
-
params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
|
| 459 |
-
params.dropout_scale = 1.f / (1.f - dropout_p);
|
| 460 |
-
params.inverse_cols = 1.f / float(params.cols);
|
| 461 |
-
params.rowscale_const = rowscale_const;
|
| 462 |
-
params.is_rms_norm = is_rms_norm;
|
| 463 |
-
|
| 464 |
-
if( launch_params.barrier_size > 0 ) {
|
| 465 |
-
// TODO Any way to avoid this?
|
| 466 |
-
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
|
| 467 |
-
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
|
| 468 |
-
params.workspace = workspace.data_ptr();
|
| 469 |
-
params.barrier = barrier.data_ptr<int>();
|
| 470 |
-
}
|
| 471 |
-
|
| 472 |
-
launcher(launch_params, false);
|
| 473 |
-
|
| 474 |
-
std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
|
| 475 |
-
if (colscale_.has_value()) {
|
| 476 |
-
result.push_back(dcolscale);
|
| 477 |
-
result.push_back(dcolscale_part);
|
| 478 |
-
}
|
| 479 |
-
return result;
|
| 480 |
-
}
|
| 481 |
-
|
| 482 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 483 |
-
|
| 484 |
-
std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
|
| 485 |
-
const at::Tensor &x0, // Input: BxSxhidden_size
|
| 486 |
-
c10::optional<const at::Tensor> &x1_, // Input: BxSxhidden_size
|
| 487 |
-
c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
|
| 488 |
-
const at::Tensor &gamma0, // hidden_size
|
| 489 |
-
c10::optional<const at::Tensor> &beta0_, // hidden_size
|
| 490 |
-
c10::optional<const at::Tensor> &gamma1_, // hidden_size
|
| 491 |
-
c10::optional<const at::Tensor> &beta1_, // hidden_size
|
| 492 |
-
const float dropout_p,
|
| 493 |
-
const float epsilon,
|
| 494 |
-
c10::optional<at::Generator> gen_,
|
| 495 |
-
bool residual_in_fp32=false,
|
| 496 |
-
bool is_rms_norm=false
|
| 497 |
-
) {
|
| 498 |
-
auto itype = x0.scalar_type();
|
| 499 |
-
auto rtype = residual_.has_value()
|
| 500 |
-
? residual_.value().scalar_type()
|
| 501 |
-
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
|
| 502 |
-
auto wtype = gamma0.scalar_type();
|
| 503 |
-
auto otype = itype;
|
| 504 |
-
auto ctype = torch::kFloat32;
|
| 505 |
-
auto mtype = torch::kUInt8;
|
| 506 |
-
|
| 507 |
-
TORCH_CHECK(x0.is_cuda());
|
| 508 |
-
TORCH_CHECK(gamma0.is_cuda());
|
| 509 |
-
|
| 510 |
-
TORCH_CHECK(x0.is_contiguous());
|
| 511 |
-
const auto sizes = x0.sizes();
|
| 512 |
-
TORCH_CHECK(x0.dim() == 2);
|
| 513 |
-
|
| 514 |
-
const int rows = sizes[0];
|
| 515 |
-
const int cols = sizes[1];
|
| 516 |
-
auto hidden_size = gamma0.numel();
|
| 517 |
-
TORCH_CHECK(hidden_size == cols);
|
| 518 |
-
|
| 519 |
-
if (x1_.has_value()) {
|
| 520 |
-
auto x1 = x1_.value();
|
| 521 |
-
TORCH_CHECK(x1.is_cuda());
|
| 522 |
-
TORCH_CHECK(x1.is_contiguous());
|
| 523 |
-
TORCH_CHECK(x1.sizes() == sizes);
|
| 524 |
-
}
|
| 525 |
-
|
| 526 |
-
if (residual_.has_value()) {
|
| 527 |
-
auto residual = residual_.value();
|
| 528 |
-
TORCH_CHECK(residual.is_cuda());
|
| 529 |
-
TORCH_CHECK(residual.is_contiguous());
|
| 530 |
-
TORCH_CHECK(residual.sizes() == sizes);
|
| 531 |
-
}
|
| 532 |
-
|
| 533 |
-
if (beta0_.has_value()) {
|
| 534 |
-
auto beta0 = beta0_.value();
|
| 535 |
-
TORCH_CHECK(beta0.dtype() == wtype);
|
| 536 |
-
TORCH_CHECK(beta0.is_cuda());
|
| 537 |
-
TORCH_CHECK(beta0.is_contiguous());
|
| 538 |
-
TORCH_CHECK(beta0.sizes() == gamma0.sizes());
|
| 539 |
-
}
|
| 540 |
-
|
| 541 |
-
if (gamma1_.has_value()) {
|
| 542 |
-
auto gamma1 = gamma1_.value();
|
| 543 |
-
TORCH_CHECK(gamma1.dtype() == wtype);
|
| 544 |
-
TORCH_CHECK(gamma1.is_cuda());
|
| 545 |
-
TORCH_CHECK(gamma1.is_contiguous());
|
| 546 |
-
TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
|
| 547 |
-
}
|
| 548 |
-
|
| 549 |
-
if (beta1_.has_value()) {
|
| 550 |
-
auto beta1 = beta1_.value();
|
| 551 |
-
TORCH_CHECK(beta1.dtype() == wtype);
|
| 552 |
-
TORCH_CHECK(beta1.is_cuda());
|
| 553 |
-
TORCH_CHECK(beta1.is_contiguous());
|
| 554 |
-
TORCH_CHECK(beta1.sizes() == gamma0.sizes());
|
| 555 |
-
}
|
| 556 |
-
|
| 557 |
-
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
|
| 558 |
-
TORCH_CHECK(epsilon >= 0.f);
|
| 559 |
-
|
| 560 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 561 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 562 |
-
at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
|
| 563 |
-
|
| 564 |
-
auto opts = x0.options();
|
| 565 |
-
|
| 566 |
-
bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
|
| 567 |
-
at::Tensor x;
|
| 568 |
-
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
|
| 569 |
-
at::Tensor dmask0, dmask1;
|
| 570 |
-
if (dropout_p > 0.f) {
|
| 571 |
-
dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
|
| 572 |
-
if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
|
| 573 |
-
};
|
| 574 |
-
auto z0 = torch::empty(sizes, opts.dtype(otype));
|
| 575 |
-
at::Tensor z1;
|
| 576 |
-
if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
|
| 577 |
-
|
| 578 |
-
auto mu = torch::empty({ rows }, opts.dtype(ctype));
|
| 579 |
-
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
|
| 580 |
-
|
| 581 |
-
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
|
| 582 |
-
|
| 583 |
-
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
| 584 |
-
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
| 585 |
-
TORCH_CHECK(dropout_p < 1.f);
|
| 586 |
-
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
| 587 |
-
launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
|
| 588 |
-
|
| 589 |
-
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
| 590 |
-
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
| 591 |
-
|
| 592 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 593 |
-
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
| 594 |
-
// Request the kernel launcher.
|
| 595 |
-
auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
| 596 |
-
|
| 597 |
-
// Set the kernel runtime parameters.
|
| 598 |
-
layer_norm::FwdParams ¶ms = launch_params.params;
|
| 599 |
-
params.rows = rows;
|
| 600 |
-
params.cols = cols;
|
| 601 |
-
params.x0 = x0.data_ptr();
|
| 602 |
-
params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
|
| 603 |
-
params.x = save_x ? x.data_ptr() : nullptr;
|
| 604 |
-
params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
|
| 605 |
-
params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
|
| 606 |
-
params.mu = mu.data_ptr();
|
| 607 |
-
params.rs = rsigma.data_ptr();
|
| 608 |
-
params.gamma = gamma0.data_ptr();
|
| 609 |
-
params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
|
| 610 |
-
params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
|
| 611 |
-
params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
|
| 612 |
-
params.z = z0.data_ptr();
|
| 613 |
-
params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
|
| 614 |
-
params.epsilon = epsilon;
|
| 615 |
-
params.dropout_scale = 1.f / (1.f - dropout_p);
|
| 616 |
-
params.inverse_cols = 1.f / float(params.cols);
|
| 617 |
-
params.is_rms_norm = is_rms_norm;
|
| 618 |
-
|
| 619 |
-
// Query the kernel-specific launch parameters.
|
| 620 |
-
launcher(launch_params, true);
|
| 621 |
-
|
| 622 |
-
at::Tensor workspace, barrier;
|
| 623 |
-
|
| 624 |
-
if (dropout_p > 0.f) {
|
| 625 |
-
// number of times random will be generated per thread, to offset philox counter in thc random
|
| 626 |
-
// state
|
| 627 |
-
int64_t counter_offset = 2 * launch_params.elts_per_thread;
|
| 628 |
-
|
| 629 |
-
// See Note [Acquire lock when using random generators]
|
| 630 |
-
{
|
| 631 |
-
std::lock_guard<std::mutex> lock(gen->mutex_);
|
| 632 |
-
params.philox_args = gen->philox_cuda_state(counter_offset);
|
| 633 |
-
}
|
| 634 |
-
}
|
| 635 |
-
|
| 636 |
-
if( launch_params.barrier_size > 0 ) {
|
| 637 |
-
auto options = x0.options();
|
| 638 |
-
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
|
| 639 |
-
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
|
| 640 |
-
params.workspace = workspace.data_ptr();
|
| 641 |
-
params.barrier = barrier.data_ptr<int>();
|
| 642 |
-
}
|
| 643 |
-
|
| 644 |
-
// Launch the kernel.
|
| 645 |
-
launcher(launch_params, false);
|
| 646 |
-
|
| 647 |
-
return { z0, z1, x, dmask0, dmask1, mu, rsigma };
|
| 648 |
-
}
|
| 649 |
-
|
| 650 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 651 |
-
|
| 652 |
-
std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
|
| 653 |
-
const at::Tensor &dz0, // BxSxhidden_size
|
| 654 |
-
c10::optional<const at::Tensor> &dz1_, // BxSxhidden_size
|
| 655 |
-
c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
|
| 656 |
-
const at::Tensor &x, // BxSxhidden_size
|
| 657 |
-
c10::optional<const at::Tensor> &dmask0_, // BxSxhidden_size
|
| 658 |
-
c10::optional<const at::Tensor> &dmask1_, // BxSxhidden_size
|
| 659 |
-
const at::Tensor &mu, // BxS, FP32!
|
| 660 |
-
const at::Tensor &rsigma, // BxS, FP32!
|
| 661 |
-
const at::Tensor &gamma0, // hidden_size
|
| 662 |
-
c10::optional<const at::Tensor> &gamma1_, // hidden_size
|
| 663 |
-
const float dropout_p,
|
| 664 |
-
const bool has_x1,
|
| 665 |
-
const bool has_residual,
|
| 666 |
-
bool is_rms_norm=false
|
| 667 |
-
) {
|
| 668 |
-
|
| 669 |
-
auto itype = dz0.scalar_type();
|
| 670 |
-
auto rtype = x.scalar_type();
|
| 671 |
-
auto wtype = gamma0.scalar_type();
|
| 672 |
-
auto otype = itype;
|
| 673 |
-
auto ctype = torch::kFloat32;
|
| 674 |
-
auto mtype = torch::kUInt8;
|
| 675 |
-
|
| 676 |
-
if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
|
| 677 |
-
|
| 678 |
-
TORCH_CHECK(dz0.dtype() == otype);
|
| 679 |
-
TORCH_CHECK(dz0.dtype() == otype);
|
| 680 |
-
TORCH_CHECK(mu.dtype() == ctype);
|
| 681 |
-
TORCH_CHECK(rsigma.dtype() == ctype);
|
| 682 |
-
|
| 683 |
-
TORCH_CHECK(x.is_cuda());
|
| 684 |
-
TORCH_CHECK(dz0.is_cuda());
|
| 685 |
-
TORCH_CHECK(mu.is_cuda());
|
| 686 |
-
TORCH_CHECK(rsigma.is_cuda());
|
| 687 |
-
TORCH_CHECK(gamma0.is_cuda());
|
| 688 |
-
|
| 689 |
-
TORCH_CHECK(x.is_contiguous());
|
| 690 |
-
TORCH_CHECK(dz0.is_contiguous());
|
| 691 |
-
|
| 692 |
-
auto sizes = x.sizes();
|
| 693 |
-
TORCH_CHECK(sizes.size() == 2);
|
| 694 |
-
auto rows = sizes[0];
|
| 695 |
-
auto cols = sizes[1];
|
| 696 |
-
TORCH_CHECK(dz0.dim() == 2);
|
| 697 |
-
TORCH_CHECK(dz0.size(1) == cols);
|
| 698 |
-
auto hidden_size = gamma0.numel();
|
| 699 |
-
TORCH_CHECK(hidden_size == cols);
|
| 700 |
-
|
| 701 |
-
if (dz1_.has_value()) {
|
| 702 |
-
auto dz1 = dz1_.value();
|
| 703 |
-
TORCH_CHECK(dz1.dtype() == otype);
|
| 704 |
-
TORCH_CHECK(dz1.is_cuda());
|
| 705 |
-
TORCH_CHECK(dz1.is_contiguous());
|
| 706 |
-
TORCH_CHECK(dz1.sizes() == sizes);
|
| 707 |
-
|
| 708 |
-
TORCH_CHECK(gamma1_.has_value());
|
| 709 |
-
auto gamma1 = gamma1_.value();
|
| 710 |
-
TORCH_CHECK(gamma1.dtype() == wtype);
|
| 711 |
-
TORCH_CHECK(gamma1.is_cuda());
|
| 712 |
-
TORCH_CHECK(gamma1.is_contiguous());
|
| 713 |
-
TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
|
| 714 |
-
}
|
| 715 |
-
|
| 716 |
-
if (dx_.has_value()) {
|
| 717 |
-
auto dx = dx_.value();
|
| 718 |
-
TORCH_CHECK(dx.dtype() == rtype);
|
| 719 |
-
TORCH_CHECK(dx.is_cuda());
|
| 720 |
-
TORCH_CHECK(dx.is_contiguous());
|
| 721 |
-
TORCH_CHECK(dx.sizes() == sizes);
|
| 722 |
-
}
|
| 723 |
-
|
| 724 |
-
if (dmask0_.has_value()) {
|
| 725 |
-
auto dmask0 = dmask0_.value();
|
| 726 |
-
TORCH_CHECK(dmask0.dtype() == mtype);
|
| 727 |
-
TORCH_CHECK(dmask0.is_cuda());
|
| 728 |
-
TORCH_CHECK(dmask0.is_contiguous());
|
| 729 |
-
TORCH_CHECK(dmask0.sizes() == sizes);
|
| 730 |
-
|
| 731 |
-
if (has_x1) {
|
| 732 |
-
TORCH_CHECK(dmask1_.has_value());
|
| 733 |
-
auto dmask1 = dmask1_.value();
|
| 734 |
-
TORCH_CHECK(dmask1.dtype() == mtype);
|
| 735 |
-
TORCH_CHECK(dmask1.is_cuda());
|
| 736 |
-
TORCH_CHECK(dmask1.is_contiguous());
|
| 737 |
-
TORCH_CHECK(dmask1.sizes() == sizes);
|
| 738 |
-
}
|
| 739 |
-
}
|
| 740 |
-
|
| 741 |
-
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
|
| 742 |
-
|
| 743 |
-
TORCH_CHECK(mu.numel() == rows);
|
| 744 |
-
TORCH_CHECK(mu.sizes() == rsigma.sizes());
|
| 745 |
-
|
| 746 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 747 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 748 |
-
at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
|
| 749 |
-
|
| 750 |
-
auto opts = x.options();
|
| 751 |
-
|
| 752 |
-
auto dx0 = torch::empty(sizes, opts.dtype(itype));
|
| 753 |
-
at::Tensor dx1;
|
| 754 |
-
if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
|
| 755 |
-
at::Tensor dresidual;
|
| 756 |
-
if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
|
| 757 |
-
auto dgamma0 = torch::empty_like(gamma0);
|
| 758 |
-
auto dbeta0 = torch::empty_like(gamma0);
|
| 759 |
-
at::Tensor dgamma1, dbeta1;
|
| 760 |
-
if (gamma1_.has_value()) {
|
| 761 |
-
dgamma1 = torch::empty_like(gamma0);
|
| 762 |
-
dbeta1 = torch::empty_like(gamma0);
|
| 763 |
-
}
|
| 764 |
-
|
| 765 |
-
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
| 766 |
-
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
| 767 |
-
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
| 768 |
-
TORCH_CHECK(dropout_p < 1.f);
|
| 769 |
-
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
| 770 |
-
launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
|
| 771 |
-
|
| 772 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 773 |
-
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
| 774 |
-
auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
| 775 |
-
|
| 776 |
-
launcher(launch_params, true);
|
| 777 |
-
|
| 778 |
-
auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
| 779 |
-
auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
| 780 |
-
at::Tensor dgamma1_part, dbeta1_part;
|
| 781 |
-
if (gamma1_.has_value()) {
|
| 782 |
-
dgamma1_part = torch::zeros_like(dgamma0_part);
|
| 783 |
-
dbeta1_part = torch::zeros_like(dbeta0_part);
|
| 784 |
-
}
|
| 785 |
-
at::Tensor workspace, barrier;
|
| 786 |
-
|
| 787 |
-
layer_norm::BwdParams ¶ms = launch_params.params;
|
| 788 |
-
params.rows = rows;
|
| 789 |
-
params.cols = cols;
|
| 790 |
-
params.x = x.data_ptr();
|
| 791 |
-
params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
|
| 792 |
-
params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
|
| 793 |
-
params.mu = mu.data_ptr();
|
| 794 |
-
params.rs = rsigma.data_ptr();
|
| 795 |
-
params.gamma = gamma0.data_ptr();
|
| 796 |
-
params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
|
| 797 |
-
params.dz = dz0.data_ptr();
|
| 798 |
-
params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
|
| 799 |
-
params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
|
| 800 |
-
params.dx0 = dx0.data_ptr();
|
| 801 |
-
params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
|
| 802 |
-
params.dbeta = dbeta0.data_ptr();
|
| 803 |
-
params.dgamma = dgamma0.data_ptr();
|
| 804 |
-
params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
|
| 805 |
-
params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
|
| 806 |
-
params.dbeta_part = dbeta0_part.data_ptr();
|
| 807 |
-
params.dgamma_part = dgamma0_part.data_ptr();
|
| 808 |
-
params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
|
| 809 |
-
params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
|
| 810 |
-
params.dropout_scale = 1.f / (1.f - dropout_p);
|
| 811 |
-
params.inverse_cols = 1.f / float(params.cols);
|
| 812 |
-
params.is_rms_norm = is_rms_norm;
|
| 813 |
-
|
| 814 |
-
if( launch_params.barrier_size > 0 ) {
|
| 815 |
-
// TODO Any way to avoid this?
|
| 816 |
-
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
|
| 817 |
-
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
|
| 818 |
-
params.workspace = workspace.data_ptr();
|
| 819 |
-
params.barrier = barrier.data_ptr<int>();
|
| 820 |
-
}
|
| 821 |
-
|
| 822 |
-
launcher(launch_params, false);
|
| 823 |
-
|
| 824 |
-
std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
|
| 825 |
-
return result;
|
| 826 |
-
}
|
| 827 |
-
|
| 828 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_1024.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_1280.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_1536.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_2048.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_256.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_2560.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_3072.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_4096.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_512.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_5120.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_6144.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_7168.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_768.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_8192.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_bwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create backward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
|
| 5 |
-
|
| 6 |
-
REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 7 |
-
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
| 8 |
-
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
|
| 9 |
-
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
|
| 10 |
-
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
|
| 11 |
-
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
|
| 12 |
-
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
|
| 13 |
-
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
|
| 14 |
-
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
|
| 15 |
-
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_bwd_kernels.cuh
DELETED
|
@@ -1,534 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
|
| 3 |
-
#include "ln.h"
|
| 4 |
-
#include "ln_utils.cuh"
|
| 5 |
-
#include "ln_kernel_traits.h"
|
| 6 |
-
#include "static_switch.h"
|
| 7 |
-
|
| 8 |
-
namespace layer_norm {
|
| 9 |
-
|
| 10 |
-
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
|
| 11 |
-
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
|
| 12 |
-
void ln_bwd_kernel(layer_norm::BwdParams params) {
|
| 13 |
-
|
| 14 |
-
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
|
| 15 |
-
enum { WARPS_M = Ktraits::WARPS_M };
|
| 16 |
-
enum { WARPS_N = Ktraits::WARPS_N };
|
| 17 |
-
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
|
| 18 |
-
enum { COLS = Ktraits::COLS };
|
| 19 |
-
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
|
| 20 |
-
enum { LDGS = Ktraits::LDGS };
|
| 21 |
-
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
|
| 22 |
-
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
|
| 23 |
-
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
|
| 24 |
-
|
| 25 |
-
using input_t = typename Ktraits::input_t;
|
| 26 |
-
using compute_t = typename Ktraits::compute_t;
|
| 27 |
-
using index_t = typename Ktraits::index_t;
|
| 28 |
-
using mask_t = typename Ktraits::mask_t;
|
| 29 |
-
using Ivec = typename Ktraits::Ivec;
|
| 30 |
-
using Rvec = typename Ktraits::Rvec;
|
| 31 |
-
using Ovec = typename Ktraits::Ovec;
|
| 32 |
-
using Wvec = typename Ktraits::Wvec;
|
| 33 |
-
using Cvec = typename Ktraits::Cvec;
|
| 34 |
-
using Mvec = typename Ktraits::Mvec;
|
| 35 |
-
using Reducer = typename Ktraits::Reducer;
|
| 36 |
-
using reduce_t = typename Reducer::Type;
|
| 37 |
-
|
| 38 |
-
extern __shared__ char smem_[];
|
| 39 |
-
|
| 40 |
-
const bool has_residual = params.dresidual != nullptr;
|
| 41 |
-
const bool prenorm = params.dx != nullptr;
|
| 42 |
-
|
| 43 |
-
const index_t tidx = threadIdx.x;
|
| 44 |
-
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
|
| 45 |
-
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
|
| 46 |
-
const index_t lane = tidx % THREADS_PER_WARP;
|
| 47 |
-
const index_t warp = tidx / THREADS_PER_WARP;
|
| 48 |
-
const index_t warp_m = warp / Ktraits::WARPS_N;
|
| 49 |
-
const index_t warp_n = warp % Ktraits::WARPS_N;
|
| 50 |
-
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
|
| 51 |
-
|
| 52 |
-
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
|
| 53 |
-
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
|
| 54 |
-
|
| 55 |
-
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
|
| 56 |
-
|
| 57 |
-
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
|
| 58 |
-
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
|
| 59 |
-
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
|
| 60 |
-
|
| 61 |
-
Cvec dzy_sum[LDGS];
|
| 62 |
-
Cvec dz_sum[LDGS];
|
| 63 |
-
Cvec dcolscale_sum[LDGS];
|
| 64 |
-
|
| 65 |
-
memset(dzy_sum, 0, sizeof(dzy_sum));
|
| 66 |
-
memset(dz_sum, 0, sizeof(dz_sum));
|
| 67 |
-
if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
|
| 68 |
-
|
| 69 |
-
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
|
| 70 |
-
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
|
| 71 |
-
|
| 72 |
-
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
|
| 73 |
-
|
| 74 |
-
Sum<reduce_t> sum;
|
| 75 |
-
|
| 76 |
-
const index_t num_valid_ldgs =
|
| 77 |
-
((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
|
| 78 |
-
|
| 79 |
-
Wvec gamma[LDGS];
|
| 80 |
-
Wvec colscale[LDGS];
|
| 81 |
-
index_t idx = c;
|
| 82 |
-
#pragma unroll
|
| 83 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 84 |
-
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 85 |
-
gamma[it].load_from(params.gamma, idx);
|
| 86 |
-
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
|
| 87 |
-
idx += Ktraits::VEC_COLS_PER_LDG;
|
| 88 |
-
}
|
| 89 |
-
}
|
| 90 |
-
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
|
| 91 |
-
// last blocks with syncthreads!
|
| 92 |
-
// grid stride over rows
|
| 93 |
-
#pragma unroll 1
|
| 94 |
-
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
|
| 95 |
-
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
|
| 96 |
-
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
|
| 97 |
-
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
|
| 98 |
-
const int row_z = !Has_subset ? row + 1 : z_subset[row];
|
| 99 |
-
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
|
| 100 |
-
const bool load_dz = !Has_subset || row_z > 0;
|
| 101 |
-
const bool save_dx0 = !Has_subset || row_x0 > 0;
|
| 102 |
-
Mvec dmask[LDGS];
|
| 103 |
-
Rvec dx[LDGS];
|
| 104 |
-
compute_t dy[LDGS * NUM_ELTS];
|
| 105 |
-
compute_t y[LDGS * NUM_ELTS];
|
| 106 |
-
compute_t mdy_local = 0.f;
|
| 107 |
-
compute_t mdyy_local = 0.f;
|
| 108 |
-
// If dz is not loaded, then dy should be 0 and we don't care about the value of y.
|
| 109 |
-
if (load_dz) {
|
| 110 |
-
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 111 |
-
index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 112 |
-
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 113 |
-
#pragma unroll
|
| 114 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 115 |
-
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 116 |
-
Rvec x;
|
| 117 |
-
Ovec dz;
|
| 118 |
-
dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
|
| 119 |
-
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
|
| 120 |
-
x.load_from(params.x, idx_x);
|
| 121 |
-
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
|
| 122 |
-
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 123 |
-
idx_z += Ktraits::VEC_COLS_PER_LDG;
|
| 124 |
-
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 125 |
-
#pragma unroll
|
| 126 |
-
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
| 127 |
-
compute_t x_tmp = x.data.elt[jt];
|
| 128 |
-
compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
|
| 129 |
-
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
|
| 130 |
-
compute_t dz_tmp = dz.data.elt[jt];
|
| 131 |
-
|
| 132 |
-
mdy_local += dy_tmp;
|
| 133 |
-
mdyy_local += dy_tmp * y_tmp;
|
| 134 |
-
|
| 135 |
-
dy[it * NUM_ELTS + jt] = dy_tmp;
|
| 136 |
-
y[it * NUM_ELTS + jt] = y_tmp;
|
| 137 |
-
|
| 138 |
-
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
|
| 139 |
-
dz_sum[it].data.elt[jt] += dz_tmp;
|
| 140 |
-
}
|
| 141 |
-
}
|
| 142 |
-
}
|
| 143 |
-
} else {
|
| 144 |
-
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 145 |
-
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 146 |
-
#pragma unroll
|
| 147 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 148 |
-
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 149 |
-
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
|
| 150 |
-
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
|
| 151 |
-
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 152 |
-
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
|
| 158 |
-
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
|
| 159 |
-
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
|
| 160 |
-
|
| 161 |
-
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 162 |
-
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 163 |
-
#pragma unroll
|
| 164 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 165 |
-
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 166 |
-
Ivec dx0;
|
| 167 |
-
Rvec dresidual;
|
| 168 |
-
Ivec x0;
|
| 169 |
-
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
|
| 170 |
-
#pragma unroll
|
| 171 |
-
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
| 172 |
-
compute_t dx_tmp_res;
|
| 173 |
-
if (load_dz) {
|
| 174 |
-
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
|
| 175 |
-
compute_t y_tmp = y[it * NUM_ELTS + jt];
|
| 176 |
-
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
|
| 177 |
-
dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
|
| 178 |
-
} else {
|
| 179 |
-
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
|
| 180 |
-
}
|
| 181 |
-
if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
|
| 182 |
-
if (save_dx0) {
|
| 183 |
-
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
|
| 184 |
-
if (Is_dropout) {
|
| 185 |
-
dx0_tmp_res *= params.dropout_scale;
|
| 186 |
-
if (Has_colscale) {
|
| 187 |
-
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
|
| 188 |
-
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
|
| 189 |
-
} else {
|
| 190 |
-
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
|
| 191 |
-
}
|
| 192 |
-
} else {
|
| 193 |
-
if (Has_colscale) {
|
| 194 |
-
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
|
| 195 |
-
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
|
| 196 |
-
} else {
|
| 197 |
-
dx0.data.elt[jt] = dx0_tmp_res;
|
| 198 |
-
}
|
| 199 |
-
}
|
| 200 |
-
}
|
| 201 |
-
}
|
| 202 |
-
if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
|
| 203 |
-
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
|
| 204 |
-
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 205 |
-
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 206 |
-
}
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
} // end: grid stride loop
|
| 210 |
-
|
| 211 |
-
if( WARPS_M == 1 ) {
|
| 212 |
-
idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 213 |
-
#pragma unroll
|
| 214 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 215 |
-
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 216 |
-
dz_sum[it].store_to(params.dbeta_part, idx);
|
| 217 |
-
dzy_sum[it].store_to(params.dgamma_part, idx);
|
| 218 |
-
if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
|
| 219 |
-
idx += Ktraits::VEC_COLS_PER_LDG;
|
| 220 |
-
}
|
| 221 |
-
}
|
| 222 |
-
} else {
|
| 223 |
-
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
|
| 224 |
-
// Finalize reduction of part dgamma and dbeta for this CTA
|
| 225 |
-
// by reducing over the rows held across the WARPS_M warps
|
| 226 |
-
|
| 227 |
-
// Assumption: blockSize divides hidden size.
|
| 228 |
-
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
|
| 229 |
-
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
|
| 230 |
-
|
| 231 |
-
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 232 |
-
#pragma unroll
|
| 233 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 234 |
-
dz_sum[it].store_to(smem_wgrad, idx);
|
| 235 |
-
idx += THREADS_PER_ROW;
|
| 236 |
-
}
|
| 237 |
-
__syncthreads();
|
| 238 |
-
compute_t cta_dz_sum[NUM_RES];
|
| 239 |
-
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 240 |
-
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 241 |
-
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 242 |
-
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 243 |
-
}
|
| 244 |
-
}
|
| 245 |
-
__syncthreads();
|
| 246 |
-
|
| 247 |
-
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 248 |
-
#pragma unroll
|
| 249 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 250 |
-
dzy_sum[it].store_to(smem_wgrad, idx);
|
| 251 |
-
idx += THREADS_PER_ROW;
|
| 252 |
-
}
|
| 253 |
-
__syncthreads();
|
| 254 |
-
compute_t cta_dzy_sum[NUM_RES];
|
| 255 |
-
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 256 |
-
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 257 |
-
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 258 |
-
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 259 |
-
}
|
| 260 |
-
}
|
| 261 |
-
|
| 262 |
-
compute_t cta_dcolscale_sum[NUM_RES];
|
| 263 |
-
if (Has_colscale) {
|
| 264 |
-
__syncthreads();
|
| 265 |
-
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 266 |
-
#pragma unroll
|
| 267 |
-
for( int it = 0; it < LDGS; it++ ) {
|
| 268 |
-
dcolscale_sum[it].store_to(smem_wgrad, idx);
|
| 269 |
-
idx += THREADS_PER_ROW;
|
| 270 |
-
}
|
| 271 |
-
__syncthreads();
|
| 272 |
-
memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 273 |
-
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 274 |
-
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 275 |
-
cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 276 |
-
}
|
| 277 |
-
}
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
const index_t num_valid_writes
|
| 281 |
-
= (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
|
| 282 |
-
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
|
| 283 |
-
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
|
| 284 |
-
compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
|
| 285 |
-
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 286 |
-
if (Is_even_cols || (jt < num_valid_writes)) {
|
| 287 |
-
*dgamma_part = cta_dzy_sum[jt];
|
| 288 |
-
dgamma_part += Ktraits::THREADS_PER_CTA;
|
| 289 |
-
*dbeta_part = cta_dz_sum[jt];
|
| 290 |
-
dbeta_part += Ktraits::THREADS_PER_CTA;
|
| 291 |
-
if (Has_colscale) {
|
| 292 |
-
*dcolscale_part = cta_dcolscale_sum[jt];
|
| 293 |
-
dcolscale_part += Ktraits::THREADS_PER_CTA;
|
| 294 |
-
}
|
| 295 |
-
}
|
| 296 |
-
}
|
| 297 |
-
|
| 298 |
-
}
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
|
| 302 |
-
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
|
| 303 |
-
void ln_bwd_finalize_kernel(BwdParams params)
|
| 304 |
-
{
|
| 305 |
-
|
| 306 |
-
using compute_t = typename Kernel_traits::compute_t;
|
| 307 |
-
using weight_t = typename Kernel_traits::weight_t;
|
| 308 |
-
using index_t = typename Kernel_traits::index_t;
|
| 309 |
-
using Reducer = typename Kernel_traits::Reducer;
|
| 310 |
-
using reduce_t = typename Reducer::Type;
|
| 311 |
-
|
| 312 |
-
Sum<reduce_t> sum;
|
| 313 |
-
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
|
| 314 |
-
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
|
| 315 |
-
|
| 316 |
-
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
|
| 317 |
-
|
| 318 |
-
constexpr uint32_t bidm = 0;
|
| 319 |
-
|
| 320 |
-
const uint32_t bidn = blockIdx.x;
|
| 321 |
-
const uint32_t tidx = threadIdx.x;
|
| 322 |
-
const uint32_t warp = tidx / THREADS_PER_WARP;
|
| 323 |
-
const uint32_t lane = tidx % THREADS_PER_WARP;
|
| 324 |
-
|
| 325 |
-
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
|
| 326 |
-
|
| 327 |
-
const uint32_t c = bidn * THREADS_PER_WARP + lane;
|
| 328 |
-
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
|
| 329 |
-
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
|
| 330 |
-
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
|
| 331 |
-
// Each thread sums over NUM_ELT columns.
|
| 332 |
-
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
|
| 333 |
-
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
| 334 |
-
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
| 335 |
-
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
| 336 |
-
if (Is_even_cols || col < params.cols) {
|
| 337 |
-
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
|
| 338 |
-
index_t idx = row * params.cols + col;
|
| 339 |
-
|
| 340 |
-
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
|
| 341 |
-
dbeta_part.load_from(params.dbeta_part, idx);
|
| 342 |
-
dgamma_part.load_from(params.dgamma_part, idx);
|
| 343 |
-
if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
|
| 344 |
-
#pragma unroll
|
| 345 |
-
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 346 |
-
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
|
| 347 |
-
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
|
| 348 |
-
if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
|
| 349 |
-
}
|
| 350 |
-
}
|
| 351 |
-
}
|
| 352 |
-
void * smem_gamma = smem_;
|
| 353 |
-
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 354 |
-
void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 355 |
-
|
| 356 |
-
const int write_row = warp;
|
| 357 |
-
const int write_col = lane ^ write_row;
|
| 358 |
-
const int write_idx = write_row * THREADS_PER_WARP + write_col;
|
| 359 |
-
|
| 360 |
-
dgamma_local.store_to(smem_gamma, write_idx);
|
| 361 |
-
dbeta_local.store_to(smem_beta, write_idx);
|
| 362 |
-
if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
|
| 363 |
-
|
| 364 |
-
__syncthreads();
|
| 365 |
-
|
| 366 |
-
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
|
| 367 |
-
void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 368 |
-
void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
|
| 369 |
-
void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
// More than one iter iff ROWS_PER_CTA < 32.
|
| 373 |
-
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
|
| 374 |
-
const int read_row = lane;
|
| 375 |
-
const int read_col = w ^ read_row;
|
| 376 |
-
const int read_idx = read_row * THREADS_PER_WARP + read_col;
|
| 377 |
-
|
| 378 |
-
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
| 379 |
-
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
| 380 |
-
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
| 381 |
-
|
| 382 |
-
// Load beta and gamma transposed
|
| 383 |
-
if(read_row < Kernel_traits::ROWS_PER_CTA){
|
| 384 |
-
dbeta_local.load_from(smem_beta, read_idx);
|
| 385 |
-
dgamma_local.load_from(smem_gamma, read_idx);
|
| 386 |
-
if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
// Call reducer on the loaded value(s) and convert.
|
| 390 |
-
#pragma unroll
|
| 391 |
-
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 392 |
-
compute_t b_i = dbeta_local.data.elt[it];
|
| 393 |
-
compute_t g_i = dgamma_local.data.elt[it];
|
| 394 |
-
b_i = reducer.allreduce(b_i, sum);
|
| 395 |
-
g_i = reducer.allreduce(g_i, sum);
|
| 396 |
-
|
| 397 |
-
dgamma_local.data.elt[it] = g_i;
|
| 398 |
-
dbeta_local.data.elt[it] = b_i;
|
| 399 |
-
if (Has_colscale) {
|
| 400 |
-
compute_t cs_i = dcolscale_local.data.elt[it];
|
| 401 |
-
cs_i = reducer.allreduce(cs_i, sum);
|
| 402 |
-
dcolscale_local.data.elt[it] = cs_i;
|
| 403 |
-
}
|
| 404 |
-
}
|
| 405 |
-
|
| 406 |
-
// Leader stores the result at the current column.
|
| 407 |
-
if(lane == 0){
|
| 408 |
-
dgamma_local.store_to(smem_gamma_out, w);
|
| 409 |
-
dbeta_local.store_to(smem_beta_out, w);
|
| 410 |
-
if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
|
| 411 |
-
}
|
| 412 |
-
|
| 413 |
-
}
|
| 414 |
-
|
| 415 |
-
// All writes done.
|
| 416 |
-
__syncthreads();
|
| 417 |
-
|
| 418 |
-
// Pack and store: 2-wide stores with half the threads.
|
| 419 |
-
if (Is_even_cols || col_out * 2 < params.cols) {
|
| 420 |
-
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
|
| 421 |
-
|
| 422 |
-
using src_t = typename TypeToVec2<compute_t>::Type;
|
| 423 |
-
using dst_t = typename TypeToVec2<weight_t>::Type;
|
| 424 |
-
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
|
| 425 |
-
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
|
| 426 |
-
|
| 427 |
-
dgamma_vec2.load_from(smem_gamma_out, lane);
|
| 428 |
-
dbeta_vec2.load_from(smem_beta_out, lane);
|
| 429 |
-
if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
|
| 430 |
-
#pragma unroll
|
| 431 |
-
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 432 |
-
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
|
| 433 |
-
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
|
| 434 |
-
if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
|
| 435 |
-
}
|
| 436 |
-
dgamma_out2.store_to(params.dgamma, col_out);
|
| 437 |
-
dbeta_out2.store_to(params.dbeta, col_out);
|
| 438 |
-
if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
|
| 439 |
-
}
|
| 440 |
-
}
|
| 441 |
-
}
|
| 442 |
-
}
|
| 443 |
-
} // namespace layer_norm
|
| 444 |
-
|
| 445 |
-
using namespace layer_norm;
|
| 446 |
-
|
| 447 |
-
template<
|
| 448 |
-
typename weight_t,
|
| 449 |
-
typename input_t,
|
| 450 |
-
typename residual_t,
|
| 451 |
-
typename output_t,
|
| 452 |
-
typename compute_t,
|
| 453 |
-
typename index_t,
|
| 454 |
-
int HIDDEN_SIZE,
|
| 455 |
-
int CTAS_PER_ROW,
|
| 456 |
-
int WARPS_M,
|
| 457 |
-
int WARPS_N,
|
| 458 |
-
int BYTES_PER_LDG_MAIN,
|
| 459 |
-
int BYTES_PER_LDG_FINAL
|
| 460 |
-
>
|
| 461 |
-
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
|
| 462 |
-
|
| 463 |
-
using Kernel_traits = Kernel_traits<weight_t,
|
| 464 |
-
input_t,
|
| 465 |
-
residual_t,
|
| 466 |
-
output_t,
|
| 467 |
-
compute_t,
|
| 468 |
-
index_t,
|
| 469 |
-
HIDDEN_SIZE,
|
| 470 |
-
CTAS_PER_ROW,
|
| 471 |
-
WARPS_M,
|
| 472 |
-
WARPS_N,
|
| 473 |
-
BYTES_PER_LDG_MAIN
|
| 474 |
-
>;
|
| 475 |
-
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
|
| 476 |
-
bool has_colscale = launch_params.params.colscale != nullptr;
|
| 477 |
-
bool has_subset = launch_params.params.x0_subset != nullptr;
|
| 478 |
-
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
|
| 479 |
-
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
|
| 480 |
-
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
|
| 481 |
-
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
|
| 482 |
-
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
| 483 |
-
auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
|
| 484 |
-
if( configure_params ) {
|
| 485 |
-
int ctas_per_sm;
|
| 486 |
-
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
| 487 |
-
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
|
| 488 |
-
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
| 489 |
-
launch_params.barrier_size = 0;
|
| 490 |
-
launch_params.workspace_bytes = 0;
|
| 491 |
-
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
| 492 |
-
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
| 493 |
-
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
| 494 |
-
* Kernel_traits::WARPS_M
|
| 495 |
-
* Kernel_traits::CTAS_PER_ROW
|
| 496 |
-
* sizeof(typename Kernel_traits::reduce_t)
|
| 497 |
-
* 2;
|
| 498 |
-
}
|
| 499 |
-
return;
|
| 500 |
-
}
|
| 501 |
-
|
| 502 |
-
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
|
| 503 |
-
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
|
| 504 |
-
}
|
| 505 |
-
auto stream = launch_params.stream;
|
| 506 |
-
auto ctas_per_col = launch_params.params.ctas_per_col;
|
| 507 |
-
|
| 508 |
-
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
| 509 |
-
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
|
| 510 |
-
} else {
|
| 511 |
-
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
| 512 |
-
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
| 513 |
-
void *params_ = (void *)&launch_params.params;
|
| 514 |
-
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream);
|
| 515 |
-
}
|
| 516 |
-
|
| 517 |
-
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
|
| 518 |
-
weight_t,
|
| 519 |
-
input_t,
|
| 520 |
-
residual_t,
|
| 521 |
-
output_t,
|
| 522 |
-
compute_t,
|
| 523 |
-
index_t,
|
| 524 |
-
HasColscaleConst,
|
| 525 |
-
32 * 32, // THREADS_PER_CTA
|
| 526 |
-
BYTES_PER_LDG_FINAL>;
|
| 527 |
-
|
| 528 |
-
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
|
| 529 |
-
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
|
| 530 |
-
});
|
| 531 |
-
});
|
| 532 |
-
});
|
| 533 |
-
});
|
| 534 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_norm/ln_fwd_1024.cu
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
#include "ln_fwd_kernels.cuh"
|
| 2 |
-
|
| 3 |
-
// Create forward launch function and register. Macro signature:
|
| 4 |
-
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
|
| 5 |
-
|
| 6 |
-
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
| 7 |
-
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
| 8 |
-
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
| 9 |
-
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
| 10 |
-
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
| 11 |
-
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
| 12 |
-
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
| 13 |
-
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
| 14 |
-
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
| 15 |
-
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|