medmekk (HF Staff) committed
Commit 854c683 · Parent(s): b9597c9

Torch 2.9 builds

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.

Files changed (50):
  1. .gitignore +1 -0
  2. CMakeLists.txt +0 -213
  3. README.md +1 -20
  4. api.py +0 -800
  5. build/{torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
  6. build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py +3 -3
  7. build/{torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
  8. build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +3 -3
  9. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/{_layer_norm_f622ea1_dirty.abi3.so → _layer_norm_f3fd6bf.abi3.so} +2 -2
  10. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +3 -3
  11. build/{torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} +2 -2
  12. build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +3 -3
  13. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
  14. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +0 -3
  15. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +3 -3
  16. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
  17. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so +0 -3
  18. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py +3 -3
  19. {torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/__init__.py +0 -0
  20. build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
  21. build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +9 -0
  22. {torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/layers.py +0 -0
  23. build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py +26 -0
  24. build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
  25. build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +9 -0
  26. build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py +51 -0
  27. build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py +26 -0
  28. build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so +3 -0
  29. build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py +9 -0
  30. build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py +51 -0
  31. cmake/hipify.py +0 -76
  32. cmake/utils.cmake +0 -545
  33. layer_norm/ln.h +0 -281
  34. layer_norm/ln_api.cpp +0 -828
  35. layer_norm/ln_bwd_1024.cu +0 -15
  36. layer_norm/ln_bwd_1280.cu +0 -15
  37. layer_norm/ln_bwd_1536.cu +0 -15
  38. layer_norm/ln_bwd_2048.cu +0 -15
  39. layer_norm/ln_bwd_256.cu +0 -15
  40. layer_norm/ln_bwd_2560.cu +0 -15
  41. layer_norm/ln_bwd_3072.cu +0 -15
  42. layer_norm/ln_bwd_4096.cu +0 -15
  43. layer_norm/ln_bwd_512.cu +0 -15
  44. layer_norm/ln_bwd_5120.cu +0 -15
  45. layer_norm/ln_bwd_6144.cu +0 -15
  46. layer_norm/ln_bwd_7168.cu +0 -15
  47. layer_norm/ln_bwd_768.cu +0 -15
  48. layer_norm/ln_bwd_8192.cu +0 -15
  49. layer_norm/ln_bwd_kernels.cuh +0 -534
  50. layer_norm/ln_fwd_1024.cu +0 -15
.gitignore CHANGED
@@ -1,2 +1,3 @@
  __pycache__/
+ **/__pycache__/
  *.pyc
CMakeLists.txt DELETED
@@ -1,213 +0,0 @@
- cmake_minimum_required(VERSION 3.26)
- project(layer_norm LANGUAGES CXX)
-
- set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")
-
- install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
-
- include(FetchContent)
- file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
- message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
-
- set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
-
- set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
-
- include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
-
- if(DEFINED Python_EXECUTABLE)
-   # Allow passing through the interpreter (e.g. from setup.py).
-   find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
-   if (NOT Python_FOUND)
-     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
-   endif()
- else()
-   find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
- endif()
-
- append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
-
- find_package(Torch REQUIRED)
-
- if (NOT TARGET_DEVICE STREQUAL "cuda" AND
-     NOT TARGET_DEVICE STREQUAL "rocm")
-     return()
- endif()
-
- if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
-    CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
-   set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
- else()
-   set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
- endif()
-
- if (NOT HIP_FOUND AND CUDA_FOUND)
-   set(GPU_LANG "CUDA")
-
-
-
- elseif(HIP_FOUND)
-   set(GPU_LANG "HIP")
-
-   # Importing torch recognizes and sets up some HIP/ROCm configuration but does
-   # not let cmake recognize .hip files. In order to get cmake to understand the
-   # .hip extension automatically, HIP must be enabled explicitly.
-   enable_language(HIP)
- else()
-   message(FATAL_ERROR "Can't find CUDA or HIP installation.")
- endif()
-
- if(GPU_LANG STREQUAL "CUDA")
-   clear_cuda_arches(CUDA_ARCH_FLAGS)
-   extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
-   message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
-   # Filter the target architectures by the supported supported archs
-   # since for some files we will build for all CUDA_ARCHS.
-   cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
-   message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
-
-   if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
-     list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
-   endif()
-
-   add_compile_definitions(CUDA_KERNEL)
- elseif(GPU_LANG STREQUAL "HIP")
-   set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
-   # TODO: remove this once we can set specific archs per source file set.
-   override_gpu_arches(GPU_ARCHES
-     ${GPU_LANG}
-     "${${GPU_LANG}_SUPPORTED_ARCHS}")
-
-   add_compile_definitions(ROCM_KERNEL)
- else()
-   override_gpu_arches(GPU_ARCHES
-     ${GPU_LANG}
-     "${${GPU_LANG}_SUPPORTED_ARCHS}")
- endif()
-
- get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
- list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
-
- set(TORCH_layer_norm_SRC
-   torch-ext/torch_binding.cpp torch-ext/torch_binding.h
- )
-
-
- list(APPEND SRC "${TORCH_layer_norm_SRC}")
-
-
- set(layer_norm_SRC
-   "layer_norm/ln.h"
-   "layer_norm/ln_api.cpp"
-   "layer_norm/ln_bwd_1024.cu"
-   "layer_norm/ln_bwd_1280.cu"
-   "layer_norm/ln_bwd_1536.cu"
-   "layer_norm/ln_bwd_2048.cu"
-   "layer_norm/ln_bwd_256.cu"
-   "layer_norm/ln_bwd_2560.cu"
-   "layer_norm/ln_bwd_3072.cu"
-   "layer_norm/ln_bwd_4096.cu"
-   "layer_norm/ln_bwd_512.cu"
-   "layer_norm/ln_bwd_5120.cu"
-   "layer_norm/ln_bwd_6144.cu"
-   "layer_norm/ln_bwd_7168.cu"
-   "layer_norm/ln_bwd_768.cu"
-   "layer_norm/ln_bwd_8192.cu"
-   "layer_norm/ln_bwd_kernels.cuh"
-   "layer_norm/ln_fwd_1024.cu"
-   "layer_norm/ln_fwd_1280.cu"
-   "layer_norm/ln_fwd_1536.cu"
-   "layer_norm/ln_fwd_2048.cu"
-   "layer_norm/ln_fwd_256.cu"
-   "layer_norm/ln_fwd_2560.cu"
-   "layer_norm/ln_fwd_3072.cu"
-   "layer_norm/ln_fwd_4096.cu"
-   "layer_norm/ln_fwd_512.cu"
-   "layer_norm/ln_fwd_5120.cu"
-   "layer_norm/ln_fwd_6144.cu"
-   "layer_norm/ln_fwd_7168.cu"
-   "layer_norm/ln_fwd_768.cu"
-   "layer_norm/ln_fwd_8192.cu"
-   "layer_norm/ln_fwd_kernels.cuh"
-   "layer_norm/ln_kernel_traits.h"
-   "layer_norm/ln_parallel_bwd_1024.cu"
-   "layer_norm/ln_parallel_bwd_1280.cu"
-   "layer_norm/ln_parallel_bwd_1536.cu"
-   "layer_norm/ln_parallel_bwd_2048.cu"
-   "layer_norm/ln_parallel_bwd_256.cu"
-   "layer_norm/ln_parallel_bwd_2560.cu"
-   "layer_norm/ln_parallel_bwd_3072.cu"
-   "layer_norm/ln_parallel_bwd_4096.cu"
-   "layer_norm/ln_parallel_bwd_512.cu"
-   "layer_norm/ln_parallel_bwd_5120.cu"
-   "layer_norm/ln_parallel_bwd_6144.cu"
-   "layer_norm/ln_parallel_bwd_7168.cu"
-   "layer_norm/ln_parallel_bwd_768.cu"
-   "layer_norm/ln_parallel_bwd_8192.cu"
-   "layer_norm/ln_parallel_fwd_1024.cu"
-   "layer_norm/ln_parallel_fwd_1280.cu"
-   "layer_norm/ln_parallel_fwd_1536.cu"
-   "layer_norm/ln_parallel_fwd_2048.cu"
-   "layer_norm/ln_parallel_fwd_256.cu"
-   "layer_norm/ln_parallel_fwd_2560.cu"
-   "layer_norm/ln_parallel_fwd_3072.cu"
-   "layer_norm/ln_parallel_fwd_4096.cu"
-   "layer_norm/ln_parallel_fwd_512.cu"
-   "layer_norm/ln_parallel_fwd_5120.cu"
-   "layer_norm/ln_parallel_fwd_6144.cu"
-   "layer_norm/ln_parallel_fwd_7168.cu"
-   "layer_norm/ln_parallel_fwd_768.cu"
-   "layer_norm/ln_parallel_fwd_8192.cu"
-   "layer_norm/ln_parallel_residual_bwd_kernels.cuh"
-   "layer_norm/ln_parallel_residual_fwd_kernels.cuh"
-   "layer_norm/ln_utils.cuh"
-   "layer_norm/static_switch.h"
- )
-
- # TODO: check if CLion support this:
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
- set_source_files_properties(
-   ${layer_norm_SRC}
-   PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
-
- if(GPU_LANG STREQUAL "CUDA")
-   cuda_archs_loose_intersection(layer_norm_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}")
-   message(STATUS "Capabilities for kernel layer_norm: ${layer_norm_ARCHS}")
-   set_gencode_flags_for_srcs(SRCS "${layer_norm_SRC}" CUDA_ARCHS "${layer_norm_ARCHS}")
-
-
-   foreach(_KERNEL_SRC ${layer_norm_SRC})
-     if(_KERNEL_SRC MATCHES ".*\\.cu$")
-       set_property(
-         SOURCE ${_KERNEL_SRC}
-         APPEND PROPERTY
-         COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;-U__CUDA_NO_BFLOAT16_OPERATORS__;-U__CUDA_NO_BFLOAT16_CONVERSIONS__;-U__CUDA_NO_BFLOAT162_OPERATORS__;-U__CUDA_NO_BFLOAT162_CONVERSIONS__;--expt-relaxed-constexpr;--expt-extended-lambda;--use_fast_math>"
-       )
-     endif()
-   endforeach()
-
-   foreach(_KERNEL_SRC ${layer_norm_SRC})
-     set_property(
-       SOURCE ${_KERNEL_SRC}
-       APPEND PROPERTY
-       COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-DFLASHATTENTION_DISABLE_PYBIND>"
-     )
-   endforeach()
-
-   list(APPEND SRC "${layer_norm_SRC}")
- endif()
-
-
- define_gpu_extension_target(
-   _layer_norm_711aa42_dirty
-   DESTINATION _layer_norm_711aa42_dirty
-   LANGUAGE ${GPU_LANG}
-   SOURCES ${SRC}
-   COMPILE_FLAGS ${GPU_FLAGS}
-   ARCHITECTURES ${GPU_ARCHES}
-   #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
-   USE_SABI 3
-   WITH_SOABI)
-
- target_link_options(_layer_norm_711aa42_dirty PRIVATE -static-libstdc++)
-
README.md CHANGED
@@ -2,23 +2,4 @@
  tags:
  - kernel
  ---
- This CUDA extension implements fused dropout + residual + LayerNorm, building on
- Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
- Major changes:
- - Add dropout and residual.
- - Make it work for both pre-norm and post-norm architecture.
- - Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
- - Implement RMSNorm as an option.
- - Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
-
- If you want to use it for dimensions larger than 8k, please file an issue.
-
- This extension has only been tested on A100s.
-
- ```sh
- cd csrc/layer_norm && pip install .
- ```
-
- As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
- We've instead switched to a Triton-based
- [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
+ This CUDA extension implements fused dropout + residual + LayerNorm from the [flash-attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm) repo.
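
For orientation, repos with this `build/torchXX-cxx11-cuYY-...` layout are typically consumed through the Hugging Face `kernels` package, which selects the prebuilt artifact matching the local Torch/CUDA combination. A minimal sketch; the repo id below is a placeholder, not something this commit confirms:

```python
# Hypothetical consumer-side sketch; "kernels-community/layer-norm" is a
# placeholder repo id, and shapes/dtypes are illustrative.
import torch
from kernels import get_kernel

layer_norm = get_kernel("kernels-community/layer-norm")  # placeholder id

hidden = 1024
x = torch.randn(8 * 128, hidden, device="cuda", dtype=torch.float16)
gamma = torch.ones(hidden, device="cuda", dtype=torch.float16)

# The loaded module exposes the wrappers added by this commit
# (see the torch29 __init__.py further down), e.g. dropout_add_ln_fwd.
out = layer_norm.dropout_add_ln_fwd(
    x, gamma, None, None, None, None, None,
    0.0, 1e-5, 1.0, x.shape[0], None, False, False,
)
z = out[0]
```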
api.py DELETED
@@ -1,800 +0,0 @@
- # Copyright (c) 2022, Tri Dao.
- # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
-
- import dropout_layer_norm
- import torch
- from torch.nn import init
-
-
- def maybe_align(x, alignment_in_bytes=16):
-     """Assume that x already has last dim divisible by alignment_in_bytes"""
-     # TD [2023-07-04] I'm not 100% sure that clone will align the memory
-     # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
-     return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
-
-
- def _dropout_add_layer_norm_forward(
-     x0,
-     residual,
-     gamma,
-     beta,
-     rowscale,
-     colscale,
-     dropout_p,
-     epsilon,
-     residual_in_fp32=False,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes"""
-     hidden_size = gamma.numel()
-     x0mat = x0.view((-1, hidden_size))
-     residualmat = residual.view((-1, hidden_size)) if residual is not None else None
-     rowscale = rowscale.view(-1) if rowscale is not None else None
-     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-         x0mat,
-         residualmat,
-         gamma,
-         beta,
-         rowscale,
-         colscale,
-         None,
-         None,
-         dropout_p,
-         epsilon,
-         1.0,
-         0,
-         None,
-         residual_in_fp32,
-         is_rms_norm,
-     )
-     # dmask is None if dropout_p == 0.0
-     # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
-     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
-
-
- def _dropout_add_layer_norm_backward(
-     dz,
-     dx,
-     x,
-     x0,
-     dmask,
-     mu,
-     rsigma,
-     gamma,
-     rowscale,
-     colscale,
-     dropout_p,
-     has_residual,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes
-     dx == None means that it was a post-norm architecture
-     (x = drop(x0) + residual was not returned in the fwd).
-     x0 must not be None if we have colscale.
-     """
-     hidden_size = gamma.numel()
-     xmat = x.view((-1, hidden_size))
-     dzmat = dz.view(xmat.shape)
-     dxmat = dx.view(xmat.shape) if dx is not None else None
-     x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
-     rowscale = rowscale.view(-1) if rowscale is not None else None
-     if colscale is not None:
-         assert x0 is not None, "x0 is required to compute the gradient of colscale"
-     dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
-         dzmat,
-         dxmat,
-         xmat,
-         x0mat,
-         dmask,
-         mu,
-         rsigma,
-         gamma,
-         rowscale,
-         colscale,
-         None,
-         None,
-         dropout_p,
-         1.0,
-         0,
-         has_residual,
-         is_rms_norm,
-     )
-     # dresidualmat is None if not has_residual
-     if colscale is None:
-         return dx0mat, dresidualmat, dgamma, dbeta
-     else:
-         dcolscale = rest[0]
-         return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
-
-
- def _dropout_add_layer_norm_subset_forward(
-     x0,
-     residual,
-     gamma,
-     beta,
-     colscale,
-     x0_subset,
-     out_subset,
-     dropout_p,
-     epsilon,
-     rowscale_const,
-     out_numrows,
-     residual_in_fp32=False,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes"""
-     hidden_size = gamma.numel()
-     x0mat = x0.view((-1, hidden_size))
-     residualmat = residual.view((-1, hidden_size)) if residual is not None else None
-     x0_subset = x0_subset.view(-1) if x0_subset is not None else None
-     out_subset = out_subset.view(-1) if out_subset is not None else None
-     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-         x0mat,
-         residualmat,
-         gamma,
-         beta,
-         None,
-         colscale,
-         x0_subset,
-         out_subset,
-         dropout_p,
-         epsilon,
-         rowscale_const,
-         out_numrows,
-         None,
-         residual_in_fp32,
-         is_rms_norm,
-     )
-     # dmask is None if dropout_p == 0.0
-     # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
-     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
-
-
- def _dropout_add_layer_norm_subset_backward(
-     dz,
-     dx,
-     x,
-     x0,
-     dmask,
-     mu,
-     rsigma,
-     gamma,
-     colscale,
-     x0_subset,
-     out_subset,
-     dropout_p,
-     rowscale_const,
-     x0_numrows,
-     has_residual,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes
-     dx == None means that it was a post-norm architecture
-     (x = drop(x0) + residual was not returned in the fwd).
-     x0 must not be None if we have colscale.
-     """
-     hidden_size = gamma.numel()
-     xmat = x.view((-1, hidden_size))
-     dzmat = dz.view(-1, hidden_size)
-     dxmat = dx.view(xmat.shape) if dx is not None else None
-     x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
-     x0_subset = x0_subset.view(-1) if x0_subset is not None else None
-     out_subset = out_subset.view(-1) if out_subset is not None else None
-     if colscale is not None:
-         assert x0 is not None, "x0 is required to compute the gradient of colscale"
-     dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
-         dzmat,
-         dxmat,
-         xmat,
-         x0mat,
-         dmask,
-         mu,
-         rsigma,
-         gamma,
-         None,
-         colscale,
-         x0_subset,
-         out_subset,
-         dropout_p,
-         rowscale_const,
-         x0_numrows,
-         has_residual,
-         is_rms_norm,
-     )
-     # dresidualmat is None if not has_residual
-     if colscale is None:
-         return dx0mat, dresidualmat, dgamma, dbeta
-     else:
-         dcolscale = rest[0]
-         return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
-
-
- def _dropout_add_layer_norm_parallel_residual_forward(
-     x0,
-     x1,
-     residual,
-     gamma0,
-     beta0,
-     gamma1,
-     beta1,
-     dropout_p,
-     epsilon,
-     residual_in_fp32=False,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes"""
-     hidden_size = gamma0.numel()
-     x0mat = x0.view((-1, hidden_size))
-     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
-     residualmat = residual.view((-1, hidden_size)) if residual is not None else None
-     (
-         z0mat,
-         z1mat,
-         xmat,
-         dmask0,
-         dmask1,
-         mu,
-         rsigma,
-     ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
-         x0mat,
-         x1mat,
-         residualmat,
-         gamma0,
-         beta0,
-         gamma1,
-         beta1,
-         dropout_p,
-         epsilon,
-         None,
-         residual_in_fp32,
-         is_rms_norm,
-     )
-     # dmask0 and dmask1 are None if dropout_p == 0.0
-     # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
-     return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
-
-
- def _dropout_add_layer_norm_parallel_residual_backward(
-     dz0,
-     dz1,
-     dx,
-     x,
-     dmask0,
-     dmask1,
-     mu,
-     rsigma,
-     gamma0,
-     gamma1,
-     dropout_p,
-     has_x1,
-     has_residual,
-     is_rms_norm=False,
- ):
-     """Assume that arguments are contiguous and aligned to 16 bytes
-     dx == None means that it was a post-norm architecture
-     (x = drop(x0) + residual was not returned in the fwd).
-     """
-     hidden_size = gamma0.numel()
-     xmat = x.view((-1, hidden_size))
-     dz0mat = dz0.view(xmat.shape)
-     dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
-     dxmat = dx.view(xmat.shape) if dx is not None else None
-     (
-         dx0mat,
-         dx1mat,
-         dresidualmat,
-         dgamma0,
-         dbeta0,
-         dgamma1,
-         dbeta1,
-         *rest,
-     ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
-         dz0mat,
-         dz1mat,
-         dxmat,
-         xmat,
-         dmask0,
-         dmask1,
-         mu,
-         rsigma,
-         gamma0,
-         gamma1,
-         dropout_p,
-         has_x1,
-         has_residual,
-         is_rms_norm,
-     )
-     # dresidualmat is None if not has_residual
-     return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
-
-
- class DropoutAddLayerNormFn(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x0,
-         residual,
-         gamma,
-         beta,
-         rowscale,
-         colscale,
-         dropout_p,
-         epsilon,
-         residual_in_fp32=False,
-         prenorm=False,
-         is_rms_norm=False,
-         return_dmask=False,
-     ):
-         x0 = maybe_align(x0.contiguous(), 16)
-         residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
-         gamma = maybe_align(gamma.contiguous(), 16)
-         beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
-         rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
-         colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
-         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-             x0,
-             residual,
-             gamma,
-             beta,
-             rowscale,
-             colscale,
-             dropout_p,
-             epsilon,
-             residual_in_fp32,
-             is_rms_norm,
-         )
-         # Only need to save x0 if we need to compute gradient wrt colscale
-         x0_saved = x0 if colscale is not None else None
-         ctx.save_for_backward(
-             xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
-         )
-         ctx.prenorm = prenorm
-         ctx.dropout_p = dropout_p
-         ctx.has_residual = residual is not None
-         ctx.is_rms_norm = is_rms_norm
-         ctx.has_beta = beta is not None
-         if not return_dmask:
-             return (
-                 zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
-             )
-         else:
-             dmask = (
-                 dmask.view(x0.shape)
-                 if dropout_p > 0.0
-                 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
-             )
-             ctx.mark_non_differentiable(dmask)
-             return (
-                 (zmat.view(x0.shape), dmask)
-                 if not prenorm
-                 else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
-             )
-
-     @staticmethod
-     def backward(ctx, dz, *args):
-         # assert dz.is_contiguous()
-         dz = maybe_align(dz.contiguous(), 16)  # this happens!
-         dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
-         x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
-         # x0 is None if colscale is None
-         dropout_p = ctx.dropout_p
-         has_residual = ctx.has_residual
-         dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
-             dz,
-             dx,
-             x,
-             x0,
-             dmask,
-             mu,
-             rsigma,
-             gamma,
-             rowscale,
-             colscale,
-             dropout_p,
-             has_residual,
-             ctx.is_rms_norm,
-         )
-         dx0 = dx0mat.view(x.shape)
-         dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
-         dcolscale = rest[0] if colscale is not None else None
-         return (
-             dx0,
-             dresidual,
-             dgamma,
-             dbeta if ctx.has_beta else None,
-             None,
-             dcolscale,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-         )
-
-
- class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x0,
-         residual,
-         gamma,
-         beta,
-         colscale,
-         x0_subset,
-         out_subset,
-         dropout_p,
-         epsilon,
-         rowscale_const,
-         out_numrows,
-         residual_in_fp32=False,
-         prenorm=False,
-         is_rms_norm=False,
-         return_dmask=False,
-     ):
-         x0 = maybe_align(x0.contiguous(), 16)
-         residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
-         gamma = maybe_align(gamma.contiguous(), 16)
-         beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
-         colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
-         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-             x0,
-             residual,
-             gamma,
-             beta,
-             colscale,
-             x0_subset,
-             out_subset,
-             dropout_p,
-             epsilon,
-             rowscale_const,
-             out_numrows,
-             residual_in_fp32,
-             is_rms_norm,
-         )
-         # Only need to save x0 if we need to compute gradient wrt colscale
-         x0_saved = x0 if colscale is not None else None
-         x_shape = (-1, *x0.shape[1:])
-         ctx.save_for_backward(
-             xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
-         )
-         ctx.prenorm = prenorm
-         ctx.dropout_p = dropout_p
-         ctx.rowscale_const = rowscale_const
-         ctx.x0_numrows = x0.shape[:-1].numel()
-         ctx.has_residual = residual is not None
-         ctx.is_rms_norm = is_rms_norm
-         ctx.has_beta = beta is not None
-         z_shape = (-1, *x0.shape[1:])
-         if not return_dmask:
-             return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
-         else:
-             z = zmat.view(z_shape)
-             dmask = (
-                 dmask.view(x0.shape)
-                 if dropout_p > 0.0
-                 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
-             )
-             ctx.mark_non_differentiable(dmask)
-             return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
-
-     @staticmethod
-     def backward(ctx, dz, *args):
-         # assert dz.is_contiguous()
-         dz = maybe_align(dz.contiguous(), 16)  # this happens!
-         dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
-         x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
-         # x0 is None if colscale is None
-         dropout_p = ctx.dropout_p
-         has_residual = ctx.has_residual
-         dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
-             dz,
-             dx,
-             x,
-             x0,
-             dmask,
-             mu,
-             rsigma,
-             gamma,
-             colscale,
-             x0_subset,
-             out_subset,
-             dropout_p,
-             ctx.rowscale_const,
-             ctx.x0_numrows,
-             has_residual,
-             ctx.is_rms_norm,
-         )
-         dx0 = dx0mat.view(-1, *x.shape[1:])
-         dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
-         dcolscale = rest[0] if colscale is not None else None
-         return (
-             dx0,
-             dresidual,
-             dgamma,
-             dbeta if ctx.has_beta else None,
-             dcolscale,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-         )
-
-
- class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x0,
-         x1,
-         residual,
-         gamma0,
-         beta0,
-         gamma1,
-         beta1,
-         dropout_p,
-         epsilon,
-         residual_in_fp32=False,
-         prenorm=False,
-         is_rms_norm=False,
-         return_dmask=False,
-     ):
-         x0 = maybe_align(x0.contiguous(), 16)
-         x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
-         residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
-         gamma0 = maybe_align(gamma0.contiguous(), 16)
-         beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
-         gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
-         beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
-         (
-             z0mat,
-             z1mat,
-             xmat,
-             dmask0,
-             dmask1,
-             mu,
-             rsigma,
-         ) = _dropout_add_layer_norm_parallel_residual_forward(
-             x0,
-             x1,
-             residual,
-             gamma0,
-             beta0,
-             gamma1,
-             beta1,
-             dropout_p,
-             epsilon,
-             residual_in_fp32,
-             is_rms_norm,
-         )
-         ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
-         ctx.prenorm = prenorm
-         ctx.dropout_p = dropout_p
-         ctx.has_x1 = x1 is not None
-         ctx.has_residual = residual is not None
-         ctx.is_rms_norm = is_rms_norm
-         ctx.has_beta = beta0 is not None
-         z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
-         if not return_dmask:
-             return z if not prenorm else (*z, xmat.view(x0.shape))
-         else:
-             dmask0 = (
-                 dmask0.view(x0.shape)
-                 if dropout_p > 0.0
-                 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
-             )
-             dmask1 = (
-                 dmask1.view(x0.shape)
-                 if dropout_p > 0.0 and x1 is not None
-                 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
-             )
-             ctx.mark_non_differentiable(dmask0)
-             ctx.mark_non_differentiable(dmask1)
-             return (
-                 (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
-             )
-
-     @staticmethod
-     def backward(ctx, dz0, dz1, *args):
-         dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
-         dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
-         dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
-         x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
-         dropout_p = ctx.dropout_p
-         has_x1 = ctx.has_x1
-         has_residual = ctx.has_residual
-         (
-             dx0mat,
-             dx1mat,
-             dresidualmat,
-             dgamma0,
-             dbeta0,
-             dgamma1,
-             dbeta1,
-         ) = _dropout_add_layer_norm_parallel_residual_backward(
-             dz0,
-             dz1,
-             dx,
-             x,
-             dmask0,
-             dmask1,
-             mu,
-             rsigma,
-             gamma0,
-             gamma1,
-             dropout_p,
-             has_x1,
-             has_residual,
-             ctx.is_rms_norm,
-         )
-         dx0 = dx0mat.view(x.shape)
-         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
-         dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
-         return (
-             dx0,
-             dx1,
-             dresidual,
-             dgamma0,
-             dbeta0 if ctx.has_beta else None,
-             dgamma1,
-             dbeta1 if ctx.has_beta else None,
-             None,
-             None,
-             None,
-             None,
-             None,
-             None,
-         )
-
-
- def layer_norm(x, weight, bias, epsilon):
-     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
-
-
- def dropout_add_layer_norm(
-     x0,
-     residual,
-     weight,
-     bias,
-     dropout_p,
-     epsilon,
-     rowscale=None,
-     layerscale=None,
-     prenorm=False,
-     residual_in_fp32=False,
-     return_dropout_mask=False,
- ):
-     """residual_in_fp32 only has an effect if residual is None.
-     Otherwise residual dtype is residual.dtype.
-     """
-     return DropoutAddLayerNormFn.apply(
-         x0,
-         residual,
-         weight,
-         bias,
-         rowscale,
-         layerscale,
-         dropout_p,
-         epsilon,
-         residual_in_fp32,
-         prenorm,
-         False,
-         return_dropout_mask,
-     )
-
-
- def dropout_add_layer_norm_subset(
-     x0,
-     residual,
-     weight,
-     bias,
-     dropout_p,
-     epsilon,
-     layerscale=None,
-     x0_subset=None,
-     out_subset=None,
-     rowscale_const=1.0,
-     out_numrows=0,
-     prenorm=False,
-     residual_in_fp32=False,
-     return_dropout_mask=False,
- ):
-     """residual_in_fp32 only has an effect if residual is None.
-     Otherwise residual dtype is residual.dtype.
-     """
-     return DropoutAddLayerNormSubsetFn.apply(
-         x0,
-         residual,
-         weight,
-         bias,
-         layerscale,
-         x0_subset,
-         out_subset,
-         dropout_p,
-         epsilon,
-         rowscale_const,
-         out_numrows,
-         residual_in_fp32,
-         prenorm,
-         False,
-         return_dropout_mask,
-     )
-
-
- def dropout_add_layer_norm_parallel_residual(
-     x0,
-     x1,
-     residual,
-     weight0,
-     bias0,
-     weight1,
-     bias1,
-     dropout_p,
-     epsilon,
-     prenorm=False,
-     residual_in_fp32=False,
-     return_dropout_mask=False,
- ):
-     """residual_in_fp32 only has an effect if residual is None.
-     Otherwise residual dtype is residual.dtype.
-     """
-     return DropoutAddLayerNormParallelResidualFn.apply(
-         x0,
-         x1,
-         residual,
-         weight0,
-         bias0,
-         weight1,
-         bias1,
-         dropout_p,
-         epsilon,
-         residual_in_fp32,
-         prenorm,
-         False,
-         return_dropout_mask,
-     )
-
-
- class DropoutAddLayerNorm(torch.nn.Module):
-     def __init__(
-         self,
-         hidden_size,
-         prenorm=False,
-         p=0.0,
-         eps=1e-5,
-         residual_in_fp32=False,
-         device=None,
-         dtype=None,
-     ):
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super().__init__()
-         self.prenorm = prenorm
-         self.p = p
-         self.eps = eps
-         self.residual_in_fp32 = residual_in_fp32
-         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-         self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-         self.reset_parameters()
-
-     def reset_parameters(self):
-         init.ones_(self.weight)
-         init.zeros_(self.bias)
-
-     def forward(self, x0, residual=None):
-         return dropout_add_layer_norm(
-             x0,
-             residual,
-             self.weight,
-             self.bias,
-             self.p if self.training else 0.0,
-             self.eps,
-             prenorm=self.prenorm,
-             residual_in_fp32=self.residual_in_fp32,
-         )
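
For reference, a minimal sketch of how the deleted module above was called, grounded in its own signatures and docstrings (shapes and hyperparameters are illustrative):

```python
# Sketch only: api.py (above) is removed by this commit.
import torch
from api import DropoutAddLayerNorm, dropout_add_layer_norm

hidden = 1024
x0 = torch.randn(8, 128, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

# Functional form: z = LayerNorm(dropout(x0) + residual).
# With prenorm=True the pre-normalization sum x is returned as well.
z, x = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5, prenorm=True
)

# Module form with the same semantics.
ln = DropoutAddLayerNorm(hidden, p=0.1, device="cuda", dtype=torch.float16)
z = ln(x0, residual)
```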
build/{torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:12b6de6cef24c5ee7a390d91ee2ea7069533e66440cf78ae5df7ae3beff5c1ca
- size 712024936
+ oid sha256:790cd814bbfcaf7ff83b5c68bcb91091a67f34e92b9a2494e2856462e71a3141
+ size 716945944
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
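
The `_ops.py` shim is what binds the versioned shared object into `torch.ops`; the rename from `_layer_norm_f622ea1_dirty` to `_layer_norm_f3fd6bf` just tracks the new build hash. A small sketch of how such a shim is consumed (the op name comes from the bindings elsewhere in this repo):

```python
from layer_norm._ops import add_op_namespace_prefix, ops

# Fully qualified op name, e.g. for torch.library-style lookups:
print(add_op_namespace_prefix("dropout_add_ln_fwd"))
# -> "_layer_norm_f3fd6bf::dropout_add_ln_fwd"

# The native kernel itself is then reachable via the torch.ops namespace:
# ops.dropout_add_ln_fwd(...)
```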
build/{torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fe0515daaf1bbfd1246d18bd5c1a5cd6f366059090a8b6e402955d06caaa6392
- size 716945976
+ oid sha256:b17984ef79fc9d6427c8efe0a8cc8f1f6e2777f9a8641b86556b7bb2359626ab
+ size 712024816
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/{_layer_norm_f622ea1_dirty.abi3.so → _layer_norm_f3fd6bf.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ae0d54be8ee4e3ae33f47f0b27243c9cbd5668ff7756b1dfb5dcd9e2430f5a35
- size 1231333392
+ oid sha256:7629b13b777a390df75374fc60d85311679a56a5bbd9969e138822e5c0fe2b1e
+ size 1231333360
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:04095de2e4bf9cd03f9ec481084d0c9e9e0baa0bab17a0ec9715f22f69bdfd33
- size 712024848
+ oid sha256:d6ffc9d5651e8de6440f2d4f58018a5ded07634582ae03eec5b9edf428f613a6
+ size 712024904
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df39795e047e962019cbecbb11f93d8ee1fcfb49ed8326f2edc267bc0d90da08
+ size 1231337936
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d51ec6b6da7095cf5fc18493eb4b0b1c20485f01dff4b38370979ea3d0a9dd60
- size 1231337968
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dfef6947945f8f126a284c6a8ab861e180a5e628992eeb0b4b7c7914c50a59c2
+ size 1283037344
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_f622ea1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9080934ece3b5e09db6178b1baa15b8baf9f6873e234a951a2122071e1190fba
- size 1283037376
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _layer_norm_f622ea1_dirty
- ops = torch.ops._layer_norm_f622ea1_dirty
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_layer_norm_f622ea1_dirty::{op_name}"
+     return f"_layer_norm_f3fd6bf::{op_name}"
{torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/__init__.py RENAMED
File without changes
build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2bdb57c0889ade2fc574156873c1d4b543796f2e8ad6a894be82ee2785459c9b
+ size 712029160
build/torch29-cxx11-cu126-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_layer_norm_f3fd6bf::{op_name}"
{torch-ext → build/torch29-cxx11-cu126-x86_64-linux}/layer_norm/layers.py RENAMED
File without changes
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import torch.nn as nn
+
+ from ._ops import ops
+
+ from . import layers
+
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
+     return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
+
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
+     return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
+
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
+     return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
+
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
+     return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
+
+ __all__ = [
+     "layers",
+     "dropout_add_ln_fwd",
+     "dropout_add_ln_bwd",
+     "dropout_add_ln_parallel_residual_fwd",
+     "dropout_add_ln_parallel_residual_bwd",
+ ]
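
A minimal sketch of calling the functional wrapper above; passing `None` for every optional tensor and `dropout_p=0.0` gives a plain RMSNorm. Shapes, the eps value, and the reading of `gen` as an RNG generator are assumptions:

```python
# Sketch, assuming the package is importable as `layer_norm`.
import torch
import layer_norm

hidden = 1024
x = torch.randn(8 * 128, hidden, device="cuda", dtype=torch.float16)
gamma = torch.ones(hidden, device="cuda", dtype=torch.float16)

out = layer_norm.dropout_add_ln_fwd(
    x,           # input, flattened to (rows, hidden)
    gamma,       # scale
    None,        # beta (no bias)
    None, None,  # rowscale, colscale
    None, None,  # x0_subset, z_subset
    0.0,         # dropout_p
    1e-6,        # epsilon
    1.0,         # rowscale_const
    x.shape[0],  # z_numrows
    None,        # gen (assumed: RNG generator)
    False,       # residual_in_fp32
    True,        # is_rms_norm
)
z = out[0]  # normalized output, mirroring the layers.py usage below
```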
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03e6e7ecbf276b306d89607100f78f2ce8b3385a77594676dbf0daabdce26fc7
+ size 1231338080
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+
+ from ._ops import ops
+
+
+ class LayerNorm(nn.Module):
+     weight: torch.Tensor
+     variance_epsilon: float
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         output = ops.dropout_add_ln_fwd(
+             hidden_states.view(-1, hidden_states.shape[-1]),
+             gamma = self.weight,
+             beta = None,
+             rowscale = None,
+             colscale = None,
+             x0_subset = None,
+             z_subset = None,
+             dropout_p = 0,
+             epsilon = self.variance_epsilon,
+             rowscale_const = 1.0,
+             z_numrows = hidden_states.shape[1],
+             gen = None,
+             residual_in_fp32 = False,
+             is_rms_norm = False,
+         )
+         return output[0].view(hidden_states.shape)
+
+ class LlamaRMSNorm(nn.Module):
+     weight: torch.Tensor
+     variance_epsilon: float
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         output = ops.dropout_add_ln_fwd(
+             hidden_states.view(-1, hidden_states.shape[-1]),
+             gamma = self.weight,
+             beta = None,
+             rowscale = None,
+             colscale = None,
+             x0_subset = None,
+             z_subset = None,
+             dropout_p = 0,
+             epsilon = self.variance_epsilon,
+             rowscale_const = 1.0,
+             z_numrows = hidden_states.shape[1],
+             gen = None,
+             residual_in_fp32 = False,
+             is_rms_norm = True,
+         )
+         return output[0].view(hidden_states.shape)
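
The classes above define no `__init__`; in the kernels layer-swap flow, `weight` and `variance_epsilon` are expected to come from the module being replaced. A hand-wired sketch, assuming attaching the attributes manually is sufficient:

```python
import torch
from layer_norm import layers

norm = layers.LlamaRMSNorm()  # plain nn.Module __init__
norm.weight = torch.nn.Parameter(
    torch.ones(1024, device="cuda", dtype=torch.float16)
)
norm.variance_epsilon = 1e-6

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
y = norm(x)  # dispatches to ops.dropout_add_ln_fwd with is_rms_norm=True
```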
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import torch.nn as nn
+
+ from ._ops import ops
+
+ from . import layers
+
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
+     return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
+
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
+     return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
+
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
+     return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
+
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
+     return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
+
+ __all__ = [
+     "layers",
+     "dropout_add_ln_fwd",
+     "dropout_add_ln_bwd",
+     "dropout_add_ln_parallel_residual_fwd",
+     "dropout_add_ln_parallel_residual_bwd",
+ ]
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_layer_norm_f3fd6bf.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:322e2d8fc69447be95ef7b6e85267e8769f1284419baa606732a77b1980a834d
+ size 1238333264
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _layer_norm_f3fd6bf
+ ops = torch.ops._layer_norm_f3fd6bf
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_layer_norm_f3fd6bf::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+
+ from ._ops import ops
+
+
+ class LayerNorm(nn.Module):
+     weight: torch.Tensor
+     variance_epsilon: float
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         output = ops.dropout_add_ln_fwd(
+             hidden_states.view(-1, hidden_states.shape[-1]),
+             gamma = self.weight,
+             beta = None,
+             rowscale = None,
+             colscale = None,
+             x0_subset = None,
+             z_subset = None,
+             dropout_p = 0,
+             epsilon = self.variance_epsilon,
+             rowscale_const = 1.0,
+             z_numrows = hidden_states.shape[1],
+             gen = None,
+             residual_in_fp32 = False,
+             is_rms_norm = False,
+         )
+         return output[0].view(hidden_states.shape)
+
+ class LlamaRMSNorm(nn.Module):
+     weight: torch.Tensor
+     variance_epsilon: float
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         output = ops.dropout_add_ln_fwd(
+             hidden_states.view(-1, hidden_states.shape[-1]),
+             gamma = self.weight,
+             beta = None,
+             rowscale = None,
+             colscale = None,
+             x0_subset = None,
+             z_subset = None,
+             dropout_p = 0,
+             epsilon = self.variance_epsilon,
+             rowscale_const = 1.0,
+             z_numrows = hidden_states.shape[1],
+             gen = None,
+             residual_in_fp32 = False,
+             is_rms_norm = True,
+         )
+         return output[0].view(hidden_states.shape)
cmake/hipify.py DELETED
@@ -1,76 +0,0 @@
- #!/usr/bin/env python3
- # SPDX-License-Identifier: Apache-2.0
-
- # From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py
-
- #
- # A command line tool for running pytorch's hipify preprocessor on CUDA
- # source files.
- #
- # See https://github.com/ROCm/hipify_torch
- # and <torch install dir>/utils/hipify/hipify_python.py
- #
-
- import argparse
- import os
- import shutil
-
- from torch.utils.hipify.hipify_python import hipify
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-
-     # Project directory where all the source + include files live.
-     parser.add_argument(
-         "-p",
-         "--project_dir",
-         help="The project directory.",
-     )
-
-     # Directory where hipified files are written.
-     parser.add_argument(
-         "-o",
-         "--output_dir",
-         help="The output directory.",
-     )
-
-     # Source files to convert.
-     parser.add_argument("sources",
-                         help="Source files to hipify.",
-                         nargs="*",
-                         default=[])
-
-     args = parser.parse_args()
-
-     # Limit include scope to project_dir only
-     includes = [os.path.join(args.project_dir, '*')]
-
-     # Get absolute path for all source files.
-     extra_files = [os.path.abspath(s) for s in args.sources]
-
-     # Copy sources from project directory to output directory.
-     # The directory might already exist to hold object files so we ignore that.
-     shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
-
-     hipify_result = hipify(project_directory=args.project_dir,
-                            output_directory=args.output_dir,
-                            header_include_dirs=[],
-                            includes=includes,
-                            extra_files=extra_files,
-                            show_detailed=True,
-                            is_pytorch_extension=True,
-                            hipify_extra_files_only=True)
-
-     hipified_sources = []
-     for source in args.sources:
-         s_abs = os.path.abspath(source)
-         hipified_s_abs = (hipify_result[s_abs].hipified_path if
-                           (s_abs in hipify_result
-                            and hipify_result[s_abs].hipified_path is not None)
-                           else s_abs)
-         hipified_sources.append(hipified_s_abs)
-
-     assert (len(hipified_sources) == len(args.sources))
-
-     # Print hipified source files.
-     print("\n".join(hipified_sources))
cmake/utils.cmake DELETED
@@ -1,545 +0,0 @@
- # Vendored from vLLM:
- #
- # https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
- #
- # Attempt to find the python package that uses the same python executable as
- # `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
- #
- macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
-   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
-   set(Python_EXECUTABLE ${EXECUTABLE})
-   find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
-   if (NOT Python_FOUND)
-     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
-   endif()
-   set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
-   set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
-   if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
-     message(FATAL_ERROR
-       "Python version (${_VER}) is not one of the supported versions: "
-       "${_SUPPORTED_VERSIONS_LIST}.")
-   endif()
-   message(STATUS "Found python matching: ${EXECUTABLE}.")
- endmacro()
-
- #
- # Run `EXPR` in python. The standard output of python is stored in `OUT` and
- # has trailing whitespace stripped. If an error is encountered when running
- # python, a fatal message `ERR_MSG` is issued.
- #
- function (run_python OUT EXPR ERR_MSG)
-   execute_process(
-     COMMAND
-     "${Python_EXECUTABLE}" "-c" "${EXPR}"
-     OUTPUT_VARIABLE PYTHON_OUT
-     RESULT_VARIABLE PYTHON_ERROR_CODE
-     ERROR_VARIABLE PYTHON_STDERR
-     OUTPUT_STRIP_TRAILING_WHITESPACE)
-
-   if(NOT PYTHON_ERROR_CODE EQUAL 0)
-     message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
-   endif()
-   set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
- endfunction()
-
- # Run `EXPR` in python after importing `PKG`. Use the result of this to extend
- # `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
- macro (append_cmake_prefix_path PKG EXPR)
-   run_python(_PREFIX_PATH
-     "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
-   list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
- endmacro()
-
- #
- # Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
- # of CUDA source files. The names of the corresponding "hipified" sources are
- # stored in `OUT_SRCS`.
- #
- function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
-   #
-   # Split into C++ and non-C++ (i.e. CUDA) sources.
-   #
-   set(NODUP_SRCS ${ORIG_SRCS})
-   list(REMOVE_DUPLICATES NODUP_SRCS)
-   set(SRCS ${NODUP_SRCS})
-   set(CXX_SRCS ${NODUP_SRCS})
-   list(FILTER SRCS INCLUDE REGEX "\.cu$")
-   list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
-
-   #
-   # Generate ROCm/HIP source file names from CUDA file names.
-   # Since HIP files are generated code, they will appear in the build area
-   # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
-   #
-   set(HIP_SRCS)
-   foreach (SRC ${SRCS})
-     get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
-     string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
-     string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
-
-     if(include_dirs)
-       # Copy over include directories from the original CUDA file.
-       set_source_files_properties(
-         ${SRC}
-         PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
-     endif()
-
-     list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
-   endforeach()
-
-   add_custom_target(
-     hipify${NAME}
-     COMMAND "${Python_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
-     DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
-     BYPRODUCTS ${HIP_SRCS}
-     COMMENT "Running hipify on ${NAME} extension source files.")
-
-   # Swap out original extension sources with hipified sources.
-   list(APPEND HIP_SRCS ${CXX_SRCS})
-   set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
- endfunction()
-
- #
- # Get additional GPU compiler flags from torch.
- #
- function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
-   if (${GPU_LANG} STREQUAL "CUDA")
-     #
-     # Get common NVCC flags from torch.
-     #
-     run_python(GPU_FLAGS
-       "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
-       "Failed to determine torch nvcc compiler flags")
-
-     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
-       list(APPEND GPU_FLAGS "-DENABLE_FP8")
-       list(REMOVE_ITEM GPU_FLAGS
-         "-D__CUDA_NO_HALF_OPERATORS__"
-         "-D__CUDA_NO_HALF_CONVERSIONS__"
-         "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
-         "-D__CUDA_NO_HALF2_OPERATORS__")
-     endif()
-
-   elseif(${GPU_LANG} STREQUAL "HIP")
-     #
-     # Get common HIP/HIPCC flags from torch.
-     #
-     run_python(GPU_FLAGS
-       "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
-       "Failed to determine torch nvcc compiler flags")
-
-     list(APPEND GPU_FLAGS
-       "-DUSE_ROCM"
-       "-DENABLE_FP8"
-       "-U__HIP_NO_HALF_CONVERSIONS__"
-       "-U__HIP_NO_HALF_OPERATORS__"
-       "-fno-gpu-rdc")
-
-   endif()
-   set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
- endfunction()
-
- # Macro for converting a `gencode` version number to a cmake version number.
- macro(string_to_ver OUT_VER IN_STR)
-   string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
- endmacro()
-
- #
- # Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
- # `CUDA_ARCH_FLAGS`.
- #
- # Example:
- #   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
- #   clear_cuda_arches(CUDA_ARCH_FLAGS)
- #   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
- #   CMAKE_CUDA_FLAGS="-Wall"
- #
- macro(clear_cuda_arches CUDA_ARCH_FLAGS)
-   # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
-   string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
-     ${CMAKE_CUDA_FLAGS})
-
-   # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
-   # and passed back via the `CUDA_ARCHITECTURES` property.
-   string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
-     ${CMAKE_CUDA_FLAGS})
- endmacro()
-
- #
- # Extract unique CUDA architectures from a list of compute capabilities codes in
- # the form `<major><minor>[<letter>]`, convert them to the form sort
- # `<major>.<minor>`, dedupes them and then sorts them in ascending order and
- # stores them in `OUT_ARCHES`.
- #
- # Example:
- #   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
- #   extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
- #   OUT_ARCHES="7.5;...;9.0"
- function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
-   set(_CUDA_ARCHES)
-   foreach(_ARCH ${CUDA_ARCH_FLAGS})
-     string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
-     if (_COMPUTE)
-       set(_COMPUTE ${CMAKE_MATCH_1})
-     endif()
-
-     string_to_ver(_COMPUTE_VER ${_COMPUTE})
-     list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
-   endforeach()
-
-   list(REMOVE_DUPLICATES _CUDA_ARCHES)
-   list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
-   set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
- endfunction()
-
- #
- # For a specific file set the `-gencode` flag in compile options conditionally
- # for the CUDA language.
- #
- # Example:
- #   set_gencode_flag_for_srcs(
- #     SRCS "foo.cu"
- #     ARCH "compute_75"
- #     CODE "sm_75")
- #   adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
- #   `foo.cu` (only for the CUDA language).
- #
- macro(set_gencode_flag_for_srcs)
-   set(options)
-   set(oneValueArgs ARCH CODE)
-   set(multiValueArgs SRCS)
-   cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
-                         "${multiValueArgs}" ${ARGN} )
-   set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
-   set_property(
-     SOURCE ${arg_SRCS}
-     APPEND PROPERTY
-     COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
-   )
-
-   message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
- endmacro(set_gencode_flag_for_srcs)
-
- #
- # For a list of source files set the `-gencode` flags in the files specific
- # compile options (specifically for the CUDA language).
- #
- # arguments are:
- #  SRCS: list of source files
- #  CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
- #  BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
- #   for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
- #   that is larger than BUILD_PTX_FOR_ARCH.
- #
- macro(set_gencode_flags_for_srcs)
-   set(options)
-   set(oneValueArgs BUILD_PTX_FOR_ARCH)
-   set(multiValueArgs SRCS CUDA_ARCHS)
-   cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
-                         "${multiValueArgs}" ${ARGN} )
-
-   foreach(_ARCH ${arg_CUDA_ARCHS})
-     # handle +PTX suffix: generate both sm and ptx codes if requested
-     string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
-     if(NOT _HAS_PTX EQUAL -1)
-       string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
-       string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
-       set_gencode_flag_for_srcs(
-         SRCS ${arg_SRCS}
-         ARCH "compute_${_STRIPPED_ARCH}"
-         CODE "sm_${_STRIPPED_ARCH}")
-       set_gencode_flag_for_srcs(
-         SRCS ${arg_SRCS}
-         ARCH "compute_${_STRIPPED_ARCH}"
-         CODE "compute_${_STRIPPED_ARCH}")
-     else()
-       string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
-       set_gencode_flag_for_srcs(
-         SRCS ${arg_SRCS}
-         ARCH "compute_${_STRIPPED_ARCH}"
-         CODE "sm_${_STRIPPED_ARCH}")
-     endif()
-   endforeach()
-
-   if (${arg_BUILD_PTX_FOR_ARCH})
-     list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
-     list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
-     if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
-       string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
-       set_gencode_flag_for_srcs(
-         SRCS ${arg_SRCS}
-         ARCH "compute_${_PTX_ARCH}"
-         CODE "compute_${_PTX_ARCH}")
-     endif()
-   endif()
- endmacro()
-
- #
- # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
- # `<major>.<minor>[letter]` compute the "loose intersection" with the
- # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
- # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
- # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
- # architecture in `SRC_CUDA_ARCHS`.
- # The loose intersection is defined as:
- #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
- # where `<=` is the version comparison operator.
- # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
- # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
- # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
- # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
- # x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
- # The result is stored in `OUT_CUDA_ARCHS`.
- #
- # Example:
- #   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
- #   TGT_CUDA_ARCHS="8.0;8.9;9.0"
- #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
- #   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
- #
- # Example With PTX:
- #   SRC_CUDA_ARCHS="8.0+PTX"
- #   TGT_CUDA_ARCHS="9.0"
- #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
- #   OUT_CUDA_ARCHS="8.0+PTX"
- #
- function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-   set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
-   set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
-
-   # handle +PTX suffix: separate base arch for matching, record PTX requests
-   set(_PTX_ARCHS)
-   foreach(_arch ${_SRC_CUDA_ARCHS})
-     if(_arch MATCHES "\\+PTX$")
-       string(REPLACE "+PTX" "" _base "${_arch}")
-       list(APPEND _PTX_ARCHS "${_base}")
-       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
-       list(APPEND _SRC_CUDA_ARCHS "${_base}")
-     endif()
-   endforeach()
-   list(REMOVE_DUPLICATES _PTX_ARCHS)
-   list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
-
-   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
-   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
-   set(_CUDA_ARCHS)
-   foreach(_arch ${_SRC_CUDA_ARCHS})
-     if(_arch MATCHES "\\a$")
-       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
-       string(REPLACE "a" "" _base "${_arch}")
-       if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
-         list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
-         list(APPEND _CUDA_ARCHS "${_arch}")
-       endif()
-     endif()
-   endforeach()
-
-   list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
-
-   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
-   # is less or equal to ARCH (but has the same major version since SASS binary
-   # compatibility is only forward compatible within the same major version).
-   foreach(_ARCH ${_TGT_CUDA_ARCHS})
-     set(_TMP_ARCH)
-     # Extract the major version of the target arch
-     string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
-     foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
-       # Extract the major version of the source arch
-       string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
-       # Check version-less-or-equal, and allow PTX arches to match across majors
-       if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
-         if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
-           set(_TMP_ARCH "${_SRC_ARCH}")
-         endif()
-       else()
-         # If we hit a version greater than the target, we can break
-         break()
-       endif()
-     endforeach()
-
-     # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
-     if (_TMP_ARCH)
-       list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
-     endif()
-   endforeach()
-
-   list(REMOVE_DUPLICATES _CUDA_ARCHS)
-
-   # reapply +PTX suffix to architectures that requested PTX
-   set(_FINAL_ARCHS)
-   foreach(_arch ${_CUDA_ARCHS})
-     if(_arch IN_LIST _PTX_ARCHS)
-       list(APPEND _FINAL_ARCHS "${_arch}+PTX")
-     else()
-       list(APPEND _FINAL_ARCHS "${_arch}")
-     endif()
-   endforeach()
-   set(_CUDA_ARCHS ${_FINAL_ARCHS})
-
-   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
- endfunction()
-
- #
- # For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
- # `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
- # The loose intersection is defined as:
- #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
- # where `<=` is the version comparison operator.
- # In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
- # in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
- # The result is stored in `OUT_ROCM_ARCHS`.
- #
- # Example:
- #   SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
- #   TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
- #   hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
- #   OUT_ROCM_ARCHS="gfx906;gfx908"
- #
- function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
-   list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
-
-   # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
-   # and x is a letter. We can sort them by string comparison which works for this format.
-   list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
-
-   set(_ROCM_ARCHS)
-
-   # Find the intersection of supported architectures
-   foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
-     if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
-       list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
-     endif()
-   endforeach()
-
-   list(REMOVE_DUPLICATES _ROCM_ARCHS)
-   set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
- endfunction()
-
- #
- # Override the GPU architectures detected by cmake/torch and filter them by
- # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
- # `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
- # the architectures on a per file basis.
- #
- # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
- #
- macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
-   set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
-   message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
-
-   if (${GPU_LANG} STREQUAL "HIP")
-     #
-     # `GPU_ARCHES` controls the `--offload-arch` flags.
-     #
-     # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
-     # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
-     # "rocm_agent_enumerator" in "enable_language(HIP)"
-     # (in file Modules/CMakeDetermineHIPCompiler.cmake)
-     #
-     if(DEFINED ENV{PYTORCH_ROCM_ARCH})
-       set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
-     else()
-       set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
-     endif()
-     #
-     # Find the intersection of the supported + detected architectures to
-     # set the module architecture flags.
-     #
-     set(${GPU_ARCHES})
-     foreach (_ARCH ${HIP_ARCHITECTURES})
-       if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
-         list(APPEND ${GPU_ARCHES} ${_ARCH})
-       endif()
-     endforeach()
-
-     if(NOT ${GPU_ARCHES})
-       message(FATAL_ERROR
-         "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
-         " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
-     endif()
-   endif()
- endmacro()
-
- #
- # Define a target named `GPU_MOD_NAME` for a single extension. The
- # arguments are:
- #
- # DESTINATION <dest>         - Module destination directory.
- # LANGUAGE <lang>            - The GPU language for this module, e.g CUDA, HIP,
- #                              etc.
- # SOURCES <sources>          - List of source files relative to CMakeLists.txt
- #                              directory.
- #
- # Optional arguments:
- #
- # ARCHITECTURES <arches>     - A list of target GPU architectures in cmake
- #                              format.
- #                              Refer `CMAKE_CUDA_ARCHITECTURES` documentation
- #                              and `CMAKE_HIP_ARCHITECTURES` for more info.
- #                              ARCHITECTURES will use cmake's defaults if
- #                              not provided.
- # COMPILE_FLAGS <flags>      - Extra compiler flags passed to NVCC/hip.
- # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
- # LIBRARIES <libraries>      - Extra link libraries.
- # WITH_SOABI                 - Generate library with python SOABI suffix name.
- # USE_SABI <version>         - Use python stable api <version>
- #
- # Note: optimization level/debug info is set via cmake build type.
- #
- function (define_gpu_extension_target GPU_MOD_NAME)
-   cmake_parse_arguments(PARSE_ARGV 1
-     GPU
-     "WITH_SOABI"
-     "DESTINATION;LANGUAGE;USE_SABI"
-     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
-
-   # Add hipify preprocessing step when building with HIP/ROCm.
-   if (GPU_LANGUAGE STREQUAL "HIP")
-     hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
-   endif()
-
-   if (GPU_WITH_SOABI)
-     set(GPU_WITH_SOABI WITH_SOABI)
-   else()
-     set(GPU_WITH_SOABI)
-   endif()
-
-   if (GPU_USE_SABI)
-     Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
-   else()
-     Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
-   endif()
-
-   if (GPU_LANGUAGE STREQUAL "HIP")
-     # Make this target dependent on the hipify preprocessor step.
-     add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
-   endif()
-
-   if (GPU_ARCHITECTURES)
-     set_target_properties(${GPU_MOD_NAME} PROPERTIES
-       ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
-   endif()
-
-   set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
-
-   target_compile_options(${GPU_MOD_NAME} PRIVATE
-     $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
-
-   target_compile_definitions(${GPU_MOD_NAME} PRIVATE
-     "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
-
-   target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
-     ${GPU_INCLUDE_DIRECTORIES})
-
-   target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
-
-   # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
-   # dependencies that are not necessary and may not be installed.
-   if (GPU_LANGUAGE STREQUAL "CUDA")
-     target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
-   else()
-     target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
-   endif()
-
-   install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
- endfunction()
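
Note: the "loose intersection" computed by `cuda_archs_loose_intersection` above is the subtle part of this file. As a minimal standalone C++ sketch of the same matching rule (sort the source arches ascending, then for each target arch keep the highest source arch that compares less-or-equal and shares its major version), with our own names and with the `+PTX` and `x.0a` special cases omitted for brevity:

#include <algorithm>
#include <iostream>
#include <vector>

struct Arch { int major, minor; };

// Version comparison corresponding to CMake's VERSION_LESS_EQUAL.
static bool operator<=(const Arch& a, const Arch& b) {
    return a.major < b.major || (a.major == b.major && a.minor <= b.minor);
}

// For each target arch, keep max{ s in src | s <= t } when it shares t's major
// version (SASS binary compatibility is only forward within a major version).
std::vector<Arch> loose_intersection(std::vector<Arch> src,
                                     const std::vector<Arch>& tgt) {
    std::sort(src.begin(), src.end(), [](const Arch& a, const Arch& b) {
        return a.major != b.major ? a.major < b.major : a.minor < b.minor;
    });
    std::vector<Arch> out;
    for (const Arch& t : tgt) {
        const Arch* best = nullptr;
        for (const Arch& s : src) {
            if (!(s <= t)) break;               // src is sorted ascending
            if (s.major == t.major) best = &s;  // same-major candidates only
        }
        if (best) out.push_back(*best);         // dedup omitted for brevity
    }
    return out;
}

int main() {
    // Mirrors the example in the comment above (minus 9.0a):
    // src={7.5,8.0,8.6,9.0}, tgt={8.0,8.9,9.0} -> {8.0,8.6,9.0}.
    for (const Arch& a : loose_intersection({{7,5},{8,0},{8,6},{9,0}},
                                            {{8,0},{8,9},{9,0}}))
        std::cout << a.major << "." << a.minor << " ";
    std::cout << "\n";
}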
layer_norm/ln.h DELETED
@@ -1,281 +0,0 @@
- #pragma once
-
- #include <unordered_map>
- #include <cuda_fp16.h>
- #include <cuda_bf16.h>
-
- #ifdef OLD_GENERATOR_PATH
- #include <ATen/CUDAGeneratorImpl.h>
- #else
- #include <ATen/cuda/CUDAGeneratorImpl.h>
- #endif
-
- namespace layer_norm {
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename Params>
- struct LaunchParams{
-
-     size_t elts_per_thread;
-     size_t workspace_bytes;
-     size_t barrier_size;
-
-     cudaDeviceProp * props;
-
-     cudaStream_t stream;
-
-     Params params;
-
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- struct ParamsBase {
-     ParamsBase()
-         : ctas_per_col(0)
-         , rows(0)
-         , cols(0)
-         , x(nullptr)
-         , mu(nullptr)
-         , rs(nullptr)
-         , gamma(nullptr)
-         , gamma1(nullptr)
-         , rowscale(nullptr)
-         , colscale(nullptr)
-         , dropout_keep_p(1.f)
-         , dropout_scale(1.f)
-         , is_rms_norm(false)
-         , workspace(nullptr)
-         , barrier(nullptr)
-     {
-     }
-
-     // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
-     int ctas_per_col;
-
-     // Input is interpreted as matrix. We normalize across columns.
-     int rows;
-     int cols;
-
-     // Common data pointers.
-     void *x0;
-     void *x1;
-     void *residual;
-     void *x;
-     void *dmask;
-     void *dmask1;
-     void *mu;
-     void *rs;
-     void *gamma;
-     void *gamma1;
-     void *rowscale;
-     void *colscale;
-     void *x0_subset;
-     void *z_subset;
-
-     float inverse_cols;
-
-     float dropout_keep_p;
-     float dropout_scale;
-     float rowscale_const;
-
-     bool is_rms_norm;
-
-     // Multi-CTA workspace in gmem.
-     void *workspace;
-
-     // Multi-CTA sync barriers in gmem.
-     int *barrier;
-
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- struct FwdParams : public ParamsBase {
-     FwdParams()
-         : ParamsBase()
-         , z(nullptr)
-         , z1(nullptr)
-         , beta(nullptr)
-         , beta1(nullptr)
-         , epsilon(0.f)
-     {
-     }
-
-     // Output of LN FWD.
-     void *z;
-     void *z1;
-     void *beta;
-     void *beta1;
-     float epsilon;
-
-     // Random state.
-     at::PhiloxCudaState philox_args;
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- struct BwdParams : public ParamsBase {
-     BwdParams()
-         : ParamsBase()
-         , dz(nullptr)
-         , dz1(nullptr)
-         , dx(nullptr)
-         , dbeta_part(nullptr)
-         , dgamma_part(nullptr)
-         , dbeta1_part(nullptr)
-         , dgamma1_part(nullptr)
-         , dcolscale_part(nullptr)
-         , dx0(nullptr)
-         , dx1(nullptr)
-         , dresidual(nullptr)
-         , dbeta(nullptr)
-         , dgamma(nullptr)
-         , dbeta1(nullptr)
-         , dgamma1(nullptr)
-         , dcolscale(nullptr)
-     {
-     }
-
-     // Input: gradient wrt. LN FWD output.
-     void *dz;
-     void *dz1;
-     // Input: gradient wrt residual.
-     void *dx;
-
-     // Workspace for Wgrad pre-reduction.
-     void *dbeta_part;
-     void *dgamma_part;
-     void *dbeta1_part;
-     void *dgamma1_part;
-     void *dcolscale_part;
-
-     // Output: Dgrad.
-     void *dx0;
-     void *dx1;
-     void *dresidual;
-     // Output: Wgrad.
-     void *dbeta;
-     void *dgamma;
-     void *dbeta1;
-     void *dgamma1;
-     void *dcolscale;
-
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
- using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
- using FunctionKey = uint64_t;
- using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
- using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
-
- extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
- extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- using fp32 = float;
- using fp16 = half;
- using bf16 = nv_bfloat16;
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename T>
- struct TypeId{};
-
- template<>
- struct TypeId<fp16>{
-     constexpr static uint32_t Value = 0;
- };
-
- template<>
- struct TypeId<bf16>{
-     constexpr static uint32_t Value = 1;
- };
-
- template<>
- struct TypeId<fp32>{
-     constexpr static uint32_t Value = 2;
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename T, int S>
- struct Type2Key{
-     constexpr static uint32_t Value = TypeId<T>::Value << S;
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename T>
- struct WeightType2Key : public Type2Key<T, 0>{};
-
- template<typename T>
- struct InputType2Key : public Type2Key<T, 2>{};
-
- template<typename T>
- struct ResidualType2Key : public Type2Key<T, 4>{};
-
- template<typename T>
- struct OutputType2Key : public Type2Key<T, 6>{};
-
- template<typename T>
- struct ComputeType2Key : public Type2Key<T, 8>{};
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename W, typename I, typename R, typename O, typename C>
- struct Types2Key{
-     constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
-     constexpr static inline uint64_t get(const uint64_t hidden_size){
-         constexpr uint64_t type_key = Value;
-         return (type_key << 32) | hidden_size;
-     }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
- struct FwdRegistrar{
-     FwdRegistrar(FwdFunction f){
-         uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
-         FWD_FUNCS.insert({ key, f });
-     }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
- struct BwdRegistrar{
-     BwdRegistrar(BwdFunction f){
-         uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
-         BWD_FUNCS.insert({ key, f });
-     }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
- struct FwdParallelRegistrar{
-     FwdParallelRegistrar(FwdFunction f){
-         uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
-         PARALLEL_FWD_FUNCS.insert({ key, f });
-     }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
- struct BwdParallelRegistrar{
-     BwdParallelRegistrar(BwdFunction f){
-         uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
-         PARALLEL_BWD_FUNCS.insert({ key, f });
-     }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- } // namespace layer_norm
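
Note: the registry key built by `TypeId`/`Type2Key`/`Types2Key` above packs five 2-bit dtype ids (weight, input, residual, output, compute) into the high word of a 64-bit key and the hidden size into the low word. A small standalone C++ sketch of the same packing for one concrete configuration; the enum and helper names below are ours, not part of the deleted header:

#include <cstdint>
#include <cstdio>

enum DtypeId : uint32_t { FP16 = 0, BF16 = 1, FP32 = 2 };  // matches TypeId<>

// Same layout as Types2Key<W,I,R,O,C>::get(hidden_size): 2 bits per dtype,
// shifted by 0/2/4/6/8, then the whole type key shifted into the high word.
uint64_t make_key(DtypeId w, DtypeId i, DtypeId r, DtypeId o, DtypeId c,
                  uint64_t hidden_size) {
    uint32_t type_key = (w << 0) | (i << 2) | (r << 4) | (o << 6) | (c << 8);
    return (uint64_t(type_key) << 32) | hidden_size;
}

int main() {
    // fp16 weights/input/residual/output, fp32 compute, hidden size 1024:
    // type_key = 0 | 0<<2 | 0<<4 | 0<<6 | 2<<8 = 0x200,
    // key      = (0x200 << 32) | 1024 = 0x0000020000000400.
    uint64_t key = make_key(FP16, FP16, FP16, FP16, FP32, 1024);
    std::printf("key = 0x%016llx\n", (unsigned long long)key);
    return 0;
}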
layer_norm/ln_api.cpp DELETED
@@ -1,828 +0,0 @@
- #include <torch/torch.h>
- #include "ATen/cuda/CUDAContext.h"
- #include <c10/cuda/CUDAGuard.h>
-
- #include "ln.h"
-
- /*
-
- Supported Type combinations:
-
- input    residual   compute   weights   output
- ============================================
- fp32     fp32       fp32      fp32      fp32
- fp16     fp32       fp32      fp32      fp16
- fp16     fp16       fp32      fp32      fp16
- bf16     fp32       fp32      fp32      bf16
- bf16     bf16       fp32      fp32      bf16
- fp16     fp16       fp32      fp16      fp16
- bf16     bf16       fp32      bf16      bf16
-
- Remarks:
- Output type = Input type
- Compute always in FP32
-
- */
-
- namespace layer_norm {
-
- // Create registries and provide runtime versions of config hash functions.
-
- FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
- BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- uint32_t get_type_id(torch::Dtype dtype){
-     if( dtype == torch::kFloat16 ) {
-         return TypeId<fp16>::Value;
-     } else if( dtype == torch::kBFloat16 ) {
-         return TypeId<bf16>::Value;
-     } else if( dtype == torch::kFloat32 ) {
-         return TypeId<fp32>::Value;
-     } else {
-         TORCH_CHECK(false, "Type not supported: ", dtype);
-     }
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
-     using namespace layer_norm;
-     uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
-     uint64_t launcher_key = (type_key << 32) | hidden_size;
-     return launcher_key;
- }
-
- }  // namespace layer_norm
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
-     auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
-     if( iter != layer_norm::FWD_FUNCS.end() ) {
-         return iter->second;
-     } else {
-         TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
-     }
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
-     auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
-     if( iter != layer_norm::BWD_FUNCS.end() ) {
-         return iter->second;
-     } else {
-         TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
-     }
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
-     auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
-     if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
-         return iter->second;
-     } else {
-         TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
-     }
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
-     auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
-     if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
-         return iter->second;
-     } else {
-         TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
-     }
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,                            // Input: BxSxhidden_size
-                                            c10::optional<const at::Tensor> &residual_,      // Residual: BxSxhidden_size
-                                            const at::Tensor &gamma,                         // hidden_size
-                                            c10::optional<const at::Tensor> &beta_,          // hidden_size
-                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
-                                            c10::optional<const at::Tensor> &colscale_,      // hidden_size
-                                            c10::optional<const at::Tensor> &x0_subset_,     // BxS
-                                            c10::optional<const at::Tensor> &z_subset_,      // BxS
-                                            const float dropout_p,
-                                            const float epsilon,
-                                            const float rowscale_const,
-                                            const int64_t z_numrows,
-                                            c10::optional<at::Generator> gen_,
-                                            bool residual_in_fp32=false,
-                                            bool is_rms_norm=false
- ) {
-     auto itype = x0.scalar_type();
-     auto rtype = residual_.has_value()
-         ? residual_.value().scalar_type()
-         : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
-     auto wtype = gamma.scalar_type();
-     auto otype = itype;
-     auto ctype = torch::kFloat32;
-     auto mtype = torch::kUInt8;
-
-     TORCH_CHECK(x0.is_cuda());
-     TORCH_CHECK(gamma.is_cuda());
-
-     TORCH_CHECK(x0.is_contiguous());
-     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
-     // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
-     // blah is then deallocated.
-     std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
-     auto sizes = c10::IntArrayRef(sizes_vec);
-     TORCH_CHECK(x0.dim() == 2);
-     TORCH_CHECK(sizes.size() == 2);
-
-     const int rows = sizes[0];
-     const int cols = sizes[1];
-     auto hidden_size = gamma.numel();
-     TORCH_CHECK(hidden_size == cols);
-
-     if (beta_.has_value()) {
-         auto beta = beta_.value();
-         TORCH_CHECK(beta.dtype() == wtype);
-         TORCH_CHECK(beta.is_cuda());
-         TORCH_CHECK(beta.is_contiguous());
-         TORCH_CHECK(beta.sizes() == gamma.sizes());
-     }
-
-     if (residual_.has_value()) {
-         auto residual = residual_.value();
-         TORCH_CHECK(residual.is_cuda());
-         TORCH_CHECK(residual.is_contiguous());
-         TORCH_CHECK(residual.sizes() == sizes);
-     }
-
-     if (rowscale_.has_value()) {
-         auto rowscale = rowscale_.value();
-         TORCH_CHECK(rowscale.is_cuda());
-         TORCH_CHECK(rowscale.is_contiguous());
-         TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(rowscale.dtype() == itype);
-     }
-
-     if (colscale_.has_value()) {
-         auto colscale = colscale_.value();
-         TORCH_CHECK(colscale.is_cuda());
-         TORCH_CHECK(colscale.is_contiguous());
-         TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
-         TORCH_CHECK(colscale.dtype() == wtype);
-     }
-
-     if (x0_subset_.has_value()) {
-         auto x0_subset = x0_subset_.value();
-         TORCH_CHECK(x0_subset.is_cuda());
-         TORCH_CHECK(x0_subset.is_contiguous());
-         TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
-
-         TORCH_CHECK(z_subset_.has_value());
-         auto z_subset = z_subset_.value();
-         TORCH_CHECK(z_subset.is_cuda());
-         TORCH_CHECK(z_subset.is_contiguous());
-         TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
-     }
-
-     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
-     TORCH_CHECK(epsilon >= 0.f);
-
-     // Otherwise the kernel will be launched from cuda:0 device
-     // Cast to char to avoid compiler warning about narrowing
-     at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
-
-     auto opts = x0.options();
-
-     bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
-     at::Tensor x;
-     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
-     at::Tensor dmask;
-     if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
-     auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
-
-     auto mu = torch::empty({ rows }, opts.dtype(ctype));
-     auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
-
-     layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
-
-     launch_params.props = at::cuda::getCurrentDeviceProperties();
-     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-     TORCH_CHECK(dropout_p < 1.f);
-     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-     launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
-     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
-     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
-     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
-     launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
-
-     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-         gen_, at::cuda::detail::getDefaultCUDAGenerator());
-
-     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-     // Request the kernel launcher.
-     auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-     // Set the kernel runtime parameters.
-     layer_norm::FwdParams &params = launch_params.params;
-     params.rows = rows;
-     params.cols = cols;
-     params.x0 = x0.data_ptr();
-     params.x = save_x ? x.data_ptr() : nullptr;
-     params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
-     params.mu = mu.data_ptr();
-     params.rs = rsigma.data_ptr();
-     params.gamma = gamma.data_ptr();
-     params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
-     params.z = z.data_ptr();
-     params.epsilon = epsilon;
-     params.dropout_scale = 1.f / (1.f - dropout_p);
-     params.inverse_cols = 1.f / float(params.cols);
-     params.rowscale_const = rowscale_const;
-     params.is_rms_norm = is_rms_norm;
-
-     // Query the kernel-specific launch parameters.
-     launcher(launch_params, true);
-
-     at::Tensor workspace, barrier;
-
-     if (dropout_p > 0.f) {
-         // number of times random will be generated per thread, to offset philox counter in thc random
-         // state
-         int64_t counter_offset = launch_params.elts_per_thread;
-
-         // See Note [Acquire lock when using random generators]
-         {
-             std::lock_guard<std::mutex> lock(gen->mutex_);
-             params.philox_args = gen->philox_cuda_state(counter_offset);
-         }
-     }
-
-     if( launch_params.barrier_size > 0 ) {
-         auto options = x0.options();
-         barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
-         workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
-         params.workspace = workspace.data_ptr();
-         params.barrier = barrier.data_ptr<int>();
-     }
-
-     // Launch the kernel.
-     launcher(launch_params, false);
-
-     return { z, x, dmask, mu, rsigma };
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,                            // BxSxhidden_size
-                                            c10::optional<const at::Tensor> &dx_,            // BxSxhidden_size
-                                            const at::Tensor &x,                             // BxSxhidden_size
-                                            c10::optional<const at::Tensor> &x0_,            // BxSxhidden_size
-                                            c10::optional<const at::Tensor> &dmask_,         // BxSxhidden_size
-                                            const at::Tensor &mu,                            // BxS, FP32!
-                                            const at::Tensor &rsigma,                        // BxS, FP32!
-                                            const at::Tensor &gamma,                         // hidden_size
-                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
-                                            c10::optional<const at::Tensor> &colscale_,      // hidden_size
-                                            c10::optional<const at::Tensor> &x0_subset_,     // BxS
-                                            c10::optional<const at::Tensor> &z_subset_,      // BxS
-                                            const float dropout_p,
-                                            const float rowscale_const,
-                                            const int64_t x0_numrows,
-                                            const bool has_residual,
-                                            bool is_rms_norm=false
- ) {
-
-     auto itype = dz.scalar_type();
-     auto rtype = x.scalar_type();
-     auto wtype = gamma.scalar_type();
-     auto otype = itype;
-     auto ctype = torch::kFloat32;
-     auto mtype = torch::kUInt8;
-
-     if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
-
-     TORCH_CHECK(dz.dtype() == otype);
-     TORCH_CHECK(mu.dtype() == ctype);
-     TORCH_CHECK(rsigma.dtype() == ctype);
-
-     TORCH_CHECK(x.is_cuda());
-     TORCH_CHECK(dz.is_cuda());
-     TORCH_CHECK(mu.is_cuda());
-     TORCH_CHECK(rsigma.is_cuda());
-     TORCH_CHECK(gamma.is_cuda());
-
-     TORCH_CHECK(x.is_contiguous());
-     TORCH_CHECK(dz.is_contiguous());
-
-     auto sizes = x.sizes();
-     TORCH_CHECK(sizes.size() == 2);
-     auto rows = sizes[0];
-     auto cols = sizes[1];
-     TORCH_CHECK(dz.dim() == 2);
-     TORCH_CHECK(dz.size(1) == cols);
-     auto hidden_size = gamma.numel();
-     TORCH_CHECK(hidden_size == cols);
-
-     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
-     // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
-     // blah is then deallocated.
-     std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
-     auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
-
-     if (dx_.has_value()) {
-         auto dx = dx_.value();
-         TORCH_CHECK(dx.dtype() == rtype);
-         TORCH_CHECK(dx.is_cuda());
-         TORCH_CHECK(dx.is_contiguous());
-         TORCH_CHECK(dx.sizes() == sizes);
-     }
-
-     if (dmask_.has_value()) {
-         auto dmask = dmask_.value();
-         TORCH_CHECK(dmask.dtype() == mtype);
-         TORCH_CHECK(dmask.is_cuda());
-         TORCH_CHECK(dmask.is_contiguous());
-         TORCH_CHECK(dmask.sizes() == x0_sizes);
-     }
-
-     if (rowscale_.has_value()) {
-         auto rowscale = rowscale_.value();
-         TORCH_CHECK(rowscale.is_cuda());
-         TORCH_CHECK(rowscale.is_contiguous());
-         TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(rowscale.dtype() == itype);
-     }
-
-     if (colscale_.has_value()) {
-         auto colscale = colscale_.value();
-         TORCH_CHECK(colscale.is_cuda());
-         TORCH_CHECK(colscale.is_contiguous());
-         TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
-         TORCH_CHECK(colscale.dtype() == wtype);
-
-         TORCH_CHECK(x0_.has_value());
-         auto x0 = x0_.value();
-         TORCH_CHECK(x0.is_cuda());
-         TORCH_CHECK(x0.is_contiguous());
-         TORCH_CHECK(x0.sizes() == x0_sizes);
-         TORCH_CHECK(x0.dtype() == itype);
-     }
-
-     if (x0_subset_.has_value()) {
-         auto x0_subset = x0_subset_.value();
-         TORCH_CHECK(x0_subset.is_cuda());
-         TORCH_CHECK(x0_subset.is_contiguous());
-         TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
-
-         TORCH_CHECK(z_subset_.has_value());
-         auto z_subset = z_subset_.value();
-         TORCH_CHECK(z_subset.is_cuda());
-         TORCH_CHECK(z_subset.is_contiguous());
-         TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
-         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
-     }
-
-     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
-
-     TORCH_CHECK(mu.numel() == rows);
-     TORCH_CHECK(mu.sizes() == rsigma.sizes());
-
-     TORCH_CHECK(gamma.numel() == cols);
-
-     // Otherwise the kernel will be launched from cuda:0 device
-     // Cast to char to avoid compiler warning about narrowing
-     at::cuda::CUDAGuard device_guard{(char)dz.get_device()};
-
-     auto opts = x.options();
-
-     auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-     at::Tensor dresidual;
-     if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
-     auto dgamma = torch::empty_like(gamma);
-     auto dbeta = torch::empty_like(gamma);
-     at::Tensor dcolscale;
-     if (colscale_.has_value()) {
-         dcolscale = torch::empty_like(colscale_.value());
-     }
-
-     layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
-     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-     launch_params.props = at::cuda::getCurrentDeviceProperties();
-     TORCH_CHECK(dropout_p < 1.f);
-     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-     launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
-     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
-     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
-     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
-     launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
-
-     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-     auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-     launcher(launch_params, true);
-
-     auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-     auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-     at::Tensor dcolscale_part;
-     if (colscale_.has_value()) {
-         dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-     }
-     at::Tensor workspace, barrier;
-
-     layer_norm::BwdParams &params = launch_params.params;
-     params.rows = rows;
-     params.cols = cols;
-     params.x = x.data_ptr();
-     params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
-     params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
-     params.mu = mu.data_ptr();
-     params.rs = rsigma.data_ptr();
-     params.gamma = gamma.data_ptr();
-     params.dz = dz.data_ptr();
-     params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
-     params.dx0 = dx0.data_ptr();
-     params.dbeta = dbeta.data_ptr();
-     params.dgamma = dgamma.data_ptr();
-     params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
-     params.dbeta_part = dbeta_part.data_ptr();
-     params.dgamma_part = dgamma_part.data_ptr();
-     params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
-     params.dropout_scale = 1.f / (1.f - dropout_p);
-     params.inverse_cols = 1.f / float(params.cols);
-     params.rowscale_const = rowscale_const;
-     params.is_rms_norm = is_rms_norm;
-
-     if( launch_params.barrier_size > 0 ) {
-         // TODO Any way to avoid this?
-         barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
-         workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
-         params.workspace = workspace.data_ptr();
-         params.barrier = barrier.data_ptr<int>();
-     }
-
-     launcher(launch_params, false);
-
-     std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
-     if (colscale_.has_value()) {
-         result.push_back(dcolscale);
-         result.push_back(dcolscale_part);
-     }
-     return result;
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
-     const at::Tensor &x0,                            // Input: BxSxhidden_size
-     c10::optional<const at::Tensor> &x1_,            // Input: BxSxhidden_size
-     c10::optional<const at::Tensor> &residual_,      // Residual: BxSxhidden_size
-     const at::Tensor &gamma0,                        // hidden_size
-     c10::optional<const at::Tensor> &beta0_,         // hidden_size
-     c10::optional<const at::Tensor> &gamma1_,        // hidden_size
-     c10::optional<const at::Tensor> &beta1_,         // hidden_size
-     const float dropout_p,
-     const float epsilon,
-     c10::optional<at::Generator> gen_,
-     bool residual_in_fp32=false,
-     bool is_rms_norm=false
- ) {
-     auto itype = x0.scalar_type();
-     auto rtype = residual_.has_value()
-         ? residual_.value().scalar_type()
-         : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
-     auto wtype = gamma0.scalar_type();
-     auto otype = itype;
-     auto ctype = torch::kFloat32;
-     auto mtype = torch::kUInt8;
-
-     TORCH_CHECK(x0.is_cuda());
-     TORCH_CHECK(gamma0.is_cuda());
-
-     TORCH_CHECK(x0.is_contiguous());
-     const auto sizes = x0.sizes();
-     TORCH_CHECK(x0.dim() == 2);
-
-     const int rows = sizes[0];
-     const int cols = sizes[1];
-     auto hidden_size = gamma0.numel();
-     TORCH_CHECK(hidden_size == cols);
-
-     if (x1_.has_value()) {
-         auto x1 = x1_.value();
-         TORCH_CHECK(x1.is_cuda());
-         TORCH_CHECK(x1.is_contiguous());
-         TORCH_CHECK(x1.sizes() == sizes);
-     }
-
-     if (residual_.has_value()) {
-         auto residual = residual_.value();
-         TORCH_CHECK(residual.is_cuda());
-         TORCH_CHECK(residual.is_contiguous());
-         TORCH_CHECK(residual.sizes() == sizes);
-     }
-
-     if (beta0_.has_value()) {
-         auto beta0 = beta0_.value();
-         TORCH_CHECK(beta0.dtype() == wtype);
-         TORCH_CHECK(beta0.is_cuda());
-         TORCH_CHECK(beta0.is_contiguous());
-         TORCH_CHECK(beta0.sizes() == gamma0.sizes());
-     }
-
-     if (gamma1_.has_value()) {
-         auto gamma1 = gamma1_.value();
-         TORCH_CHECK(gamma1.dtype() == wtype);
-         TORCH_CHECK(gamma1.is_cuda());
-         TORCH_CHECK(gamma1.is_contiguous());
-         TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
-     }
-
-     if (beta1_.has_value()) {
-         auto beta1 = beta1_.value();
-         TORCH_CHECK(beta1.dtype() == wtype);
-         TORCH_CHECK(beta1.is_cuda());
-         TORCH_CHECK(beta1.is_contiguous());
-         TORCH_CHECK(beta1.sizes() == gamma0.sizes());
-     }
-
-     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
-     TORCH_CHECK(epsilon >= 0.f);
-
-     // Otherwise the kernel will be launched from cuda:0 device
-     // Cast to char to avoid compiler warning about narrowing
-     at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
-
-     auto opts = x0.options();
-
-     bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
-     at::Tensor x;
-     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
-     at::Tensor dmask0, dmask1;
-     if (dropout_p > 0.f) {
-         dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
-         if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
-     };
-     auto z0 = torch::empty(sizes, opts.dtype(otype));
-     at::Tensor z1;
-     if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
-
-     auto mu = torch::empty({ rows }, opts.dtype(ctype));
-     auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
-
-     layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
-
-     launch_params.props = at::cuda::getCurrentDeviceProperties();
-     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-     TORCH_CHECK(dropout_p < 1.f);
-     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-     launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
-
-     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-         gen_, at::cuda::detail::getDefaultCUDAGenerator());
-
-     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-     // Request the kernel launcher.
-     auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-     // Set the kernel runtime parameters.
-     layer_norm::FwdParams &params = launch_params.params;
-     params.rows = rows;
-     params.cols = cols;
-     params.x0 = x0.data_ptr();
-     params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
-     params.x = save_x ? x.data_ptr() : nullptr;
-     params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
-     params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
-     params.mu = mu.data_ptr();
-     params.rs = rsigma.data_ptr();
-     params.gamma = gamma0.data_ptr();
-     params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
-     params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
-     params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
-     params.z = z0.data_ptr();
-     params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
-     params.epsilon = epsilon;
-     params.dropout_scale = 1.f / (1.f - dropout_p);
-     params.inverse_cols = 1.f / float(params.cols);
-     params.is_rms_norm = is_rms_norm;
-
-     // Query the kernel-specific launch parameters.
-     launcher(launch_params, true);
-
-     at::Tensor workspace, barrier;
-
-     if (dropout_p > 0.f) {
-         // number of times random will be generated per thread, to offset philox counter in thc random
-         // state
-         int64_t counter_offset = 2 * launch_params.elts_per_thread;
-
-         // See Note [Acquire lock when using random generators]
-         {
-             std::lock_guard<std::mutex> lock(gen->mutex_);
-             params.philox_args = gen->philox_cuda_state(counter_offset);
-         }
-     }
-
-     if( launch_params.barrier_size > 0 ) {
-         auto options = x0.options();
-         barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
-         workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
-         params.workspace = workspace.data_ptr();
-         params.barrier = barrier.data_ptr<int>();
-     }
-
-     // Launch the kernel.
-     launcher(launch_params, false);
-
-     return { z0, z1, x, dmask0, dmask1, mu, rsigma };
- }
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
-     const at::Tensor &dz0,                           // BxSxhidden_size
-     c10::optional<const at::Tensor> &dz1_,           // BxSxhidden_size
-     c10::optional<const at::Tensor> &dx_,            // BxSxhidden_size
-     const at::Tensor &x,                             // BxSxhidden_size
-     c10::optional<const at::Tensor> &dmask0_,        // BxSxhidden_size
-     c10::optional<const at::Tensor> &dmask1_,        // BxSxhidden_size
-     const at::Tensor &mu,                            // BxS, FP32!
-     const at::Tensor &rsigma,                        // BxS, FP32!
-     const at::Tensor &gamma0,                        // hidden_size
-     c10::optional<const at::Tensor> &gamma1_,        // hidden_size
-     const float dropout_p,
-     const bool has_x1,
-     const bool has_residual,
-     bool is_rms_norm=false
- ) {
-
-     auto itype = dz0.scalar_type();
-     auto rtype = x.scalar_type();
-     auto wtype = gamma0.scalar_type();
-     auto otype = itype;
-     auto ctype = torch::kFloat32;
-     auto mtype = torch::kUInt8;
-
-     if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
-
-     TORCH_CHECK(dz0.dtype() == otype);
-     TORCH_CHECK(dz0.dtype() == otype);
-     TORCH_CHECK(mu.dtype() == ctype);
-     TORCH_CHECK(rsigma.dtype() == ctype);
-
-     TORCH_CHECK(x.is_cuda());
-     TORCH_CHECK(dz0.is_cuda());
-     TORCH_CHECK(mu.is_cuda());
-     TORCH_CHECK(rsigma.is_cuda());
-     TORCH_CHECK(gamma0.is_cuda());
-
-     TORCH_CHECK(x.is_contiguous());
-     TORCH_CHECK(dz0.is_contiguous());
-
-     auto sizes = x.sizes();
-     TORCH_CHECK(sizes.size() == 2);
-     auto rows = sizes[0];
-     auto cols = sizes[1];
-     TORCH_CHECK(dz0.dim() == 2);
-     TORCH_CHECK(dz0.size(1) == cols);
-     auto hidden_size = gamma0.numel();
-     TORCH_CHECK(hidden_size == cols);
-
-     if (dz1_.has_value()) {
-         auto dz1 = dz1_.value();
-         TORCH_CHECK(dz1.dtype() == otype);
-         TORCH_CHECK(dz1.is_cuda());
-         TORCH_CHECK(dz1.is_contiguous());
-         TORCH_CHECK(dz1.sizes() == sizes);
-
-         TORCH_CHECK(gamma1_.has_value());
-         auto gamma1 = gamma1_.value();
-         TORCH_CHECK(gamma1.dtype() == wtype);
-         TORCH_CHECK(gamma1.is_cuda());
-         TORCH_CHECK(gamma1.is_contiguous());
-         TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
-     }
-
-     if (dx_.has_value()) {
-         auto dx = dx_.value();
-         TORCH_CHECK(dx.dtype() == rtype);
-         TORCH_CHECK(dx.is_cuda());
-         TORCH_CHECK(dx.is_contiguous());
-         TORCH_CHECK(dx.sizes() == sizes);
-     }
-
-     if (dmask0_.has_value()) {
-         auto dmask0 = dmask0_.value();
-         TORCH_CHECK(dmask0.dtype() == mtype);
-         TORCH_CHECK(dmask0.is_cuda());
-         TORCH_CHECK(dmask0.is_contiguous());
-         TORCH_CHECK(dmask0.sizes() == sizes);
-
-         if (has_x1) {
-             TORCH_CHECK(dmask1_.has_value());
-             auto dmask1 = dmask1_.value();
-             TORCH_CHECK(dmask1.dtype() == mtype);
-             TORCH_CHECK(dmask1.is_cuda());
-             TORCH_CHECK(dmask1.is_contiguous());
-             TORCH_CHECK(dmask1.sizes() == sizes);
-         }
-     }
-
-     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
-
-     TORCH_CHECK(mu.numel() == rows);
-     TORCH_CHECK(mu.sizes() == rsigma.sizes());
-
-     // Otherwise the kernel will be launched from cuda:0 device
-     // Cast to char to avoid compiler warning about narrowing
-     at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
-
-     auto opts = x.options();
-
-     auto dx0 = torch::empty(sizes, opts.dtype(itype));
-     at::Tensor dx1;
-     if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
-     at::Tensor dresidual;
-     if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
-     auto dgamma0 = torch::empty_like(gamma0);
-     auto dbeta0 = torch::empty_like(gamma0);
-     at::Tensor dgamma1, dbeta1;
-     if (gamma1_.has_value()) {
-         dgamma1 = torch::empty_like(gamma0);
-         dbeta1 = torch::empty_like(gamma0);
-     }
-
-     layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
-     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-     launch_params.props = at::cuda::getCurrentDeviceProperties();
-     TORCH_CHECK(dropout_p < 1.f);
-     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-     launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
-
-     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-     auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-     launcher(launch_params, true);
-
-     auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-     auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-     at::Tensor dgamma1_part, dbeta1_part;
-     if (gamma1_.has_value()) {
-         dgamma1_part = torch::zeros_like(dgamma0_part);
-         dbeta1_part = torch::zeros_like(dbeta0_part);
-     }
-     at::Tensor workspace, barrier;
-
-     layer_norm::BwdParams &params = launch_params.params;
-     params.rows = rows;
-     params.cols = cols;
-     params.x = x.data_ptr();
-     params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
- params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
793
- params.mu = mu.data_ptr();
794
- params.rs = rsigma.data_ptr();
795
- params.gamma = gamma0.data_ptr();
796
- params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
797
- params.dz = dz0.data_ptr();
798
- params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
799
- params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
800
- params.dx0 = dx0.data_ptr();
801
- params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
802
- params.dbeta = dbeta0.data_ptr();
803
- params.dgamma = dgamma0.data_ptr();
804
- params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
805
- params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
806
- params.dbeta_part = dbeta0_part.data_ptr();
807
- params.dgamma_part = dgamma0_part.data_ptr();
808
- params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
809
- params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
810
- params.dropout_scale = 1.f / (1.f - dropout_p);
811
- params.inverse_cols = 1.f / float(params.cols);
812
- params.is_rms_norm = is_rms_norm;
813
-
814
- if( launch_params.barrier_size > 0 ) {
815
- // TODO Any way to avoid this?
816
- barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
817
- workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
818
- params.workspace = workspace.data_ptr();
819
- params.barrier = barrier.data_ptr<int>();
820
- }
821
-
822
- launcher(launch_params, false);
823
-
824
- std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
825
- return result;
826
- }
827
-
828
- ////////////////////////////////////////////////////////////////////////////////////////////////////
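
Editor's note: the size dispatch shared by these entry points reduces to a small piece of arithmetic. A minimal standalone sketch (the function name is illustrative, not from the deleted source):

#include <cassert>

// Round the hidden size up to the nearest registered kernel size, with
// coarser granularity for wider rows -- the same arithmetic as the
// round_multiple lambda and the `multiple` selection above.
int rounded_hidden_size(int hidden_size) {
    assert(hidden_size % 8 == 0 && hidden_size <= 8192);  // mirrors the TORCH_CHECKs
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    return (hidden_size + multiple - 1) / multiple * multiple;
}
// e.g. 1000 -> 1024, 1536 -> 1536, 4104 -> 5120; every result corresponds to
// one of the REGISTER_*_LAUNCHER instantiations in the files removed below.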
 
layer_norm/ln_bwd_1024.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
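
Editor's note: the comment above names the mechanism but not the plumbing. A hedged sketch of the registration pattern such macros typically expand to, assuming a function-pointer registry keyed on (dtypes, hidden size); all names below are illustrative, not the actual macro expansion:

#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>

struct LaunchParamsStub { bool configured = false; };  // stand-in for LaunchParams<BwdParams>
using Launcher = std::function<void(LaunchParamsStub&, bool /*configure_params*/)>;
using Key = std::tuple<std::string, std::string, std::string, std::string, std::string, int>;

// Global registry that a host-side get_*_launcher() lookup would consult.
std::map<Key, Launcher>& bwd_registry() {
    static std::map<Key, Launcher> r;
    return r;
}

struct BwdRegistrar {
    BwdRegistrar(Key key, Launcher fn) { bwd_registry().emplace(std::move(key), std::move(fn)); }
};

// One REGISTER_BWD_LAUNCHER line corresponds roughly to one static registrar:
static BwdRegistrar reg_1024_bf16{
    {"bf16", "bf16", "bf16", "bf16", "fp32", 1024},
    [](LaunchParamsStub& lp, bool configure) { lp.configured = configure; }  // would call launch_<...>
};

Launcher get_bwd_launcher_sketch(const Key& key) {
    auto it = bwd_registry().find(key);
    if (it == bwd_registry().end()) throw std::invalid_argument("no kernel registered for this configuration");
    return it->second;
}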
 
layer_norm/ln_bwd_1280.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
 
layer_norm/ln_bwd_1536.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
 
layer_norm/ln_bwd_2048.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
 
layer_norm/ln_bwd_256.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
 
layer_norm/ln_bwd_2560.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
-REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
 
layer_norm/ln_bwd_3072.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
 
layer_norm/ln_bwd_4096.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
 
layer_norm/ln_bwd_512.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
 
layer_norm/ln_bwd_5120.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
 
layer_norm/ln_bwd_6144.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
 
layer_norm/ln_bwd_7168.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
-REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
 
layer_norm/ln_bwd_768.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
 
layer_norm/ln_bwd_8192.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_bwd_kernels.cuh"
-
-// Create backward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
-REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
 
layer_norm/ln_bwd_kernels.cuh DELETED
@@ -1,534 +0,0 @@
-#pragma once
-
-#include "ln.h"
-#include "ln_utils.cuh"
-#include "ln_kernel_traits.h"
-#include "static_switch.h"
-
-namespace layer_norm {
-
-template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
-__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
-void ln_bwd_kernel(layer_norm::BwdParams params) {
-
-    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
-    enum { WARPS_M = Ktraits::WARPS_M };
-    enum { WARPS_N = Ktraits::WARPS_N };
-    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
-    enum { COLS = Ktraits::COLS };
-    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
-    enum { LDGS = Ktraits::LDGS };
-    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
-    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
-    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
-
-    using input_t = typename Ktraits::input_t;
-    using compute_t = typename Ktraits::compute_t;
-    using index_t = typename Ktraits::index_t;
-    using mask_t = typename Ktraits::mask_t;
-    using Ivec = typename Ktraits::Ivec;
-    using Rvec = typename Ktraits::Rvec;
-    using Ovec = typename Ktraits::Ovec;
-    using Wvec = typename Ktraits::Wvec;
-    using Cvec = typename Ktraits::Cvec;
-    using Mvec = typename Ktraits::Mvec;
-    using Reducer = typename Ktraits::Reducer;
-    using reduce_t = typename Reducer::Type;
-
-    extern __shared__ char smem_[];
-
-    const bool has_residual = params.dresidual != nullptr;
-    const bool prenorm = params.dx != nullptr;
-
-    const index_t tidx = threadIdx.x;
-    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
-    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
-    const index_t lane = tidx % THREADS_PER_WARP;
-    const index_t warp = tidx / THREADS_PER_WARP;
-    const index_t warp_m = warp / Ktraits::WARPS_N;
-    const index_t warp_n = warp % Ktraits::WARPS_N;
-    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
-
-    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
-    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
-
-    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
-
-    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
-    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
-    const index_t *z_subset = static_cast<index_t *>(params.z_subset);
-
-    Cvec dzy_sum[LDGS];
-    Cvec dz_sum[LDGS];
-    Cvec dcolscale_sum[LDGS];
-
-    memset(dzy_sum, 0, sizeof(dzy_sum));
-    memset(dz_sum, 0, sizeof(dz_sum));
-    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
-
-    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
-    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
-
-    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
-
-    Sum<reduce_t> sum;
-
-    const index_t num_valid_ldgs =
-        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
-
-    Wvec gamma[LDGS];
-    Wvec colscale[LDGS];
-    index_t idx = c;
-    #pragma unroll
-    for( int it = 0; it < LDGS; it++ ) {
-        if (Is_even_cols || (it < num_valid_ldgs)) {
-            gamma[it].load_from(params.gamma, idx);
-            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
-            idx += Ktraits::VEC_COLS_PER_LDG;
-        }
-    }
-    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
-    // last blocks with syncthreads!
-    // grid stride over rows
-    #pragma unroll 1
-    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
-        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
-        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
-        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
-        const int row_z = !Has_subset ? row + 1 : z_subset[row];
-        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
-        const bool load_dz = !Has_subset || row_z > 0;
-        const bool save_dx0 = !Has_subset || row_x0 > 0;
-        Mvec dmask[LDGS];
-        Rvec dx[LDGS];
-        compute_t dy[LDGS * NUM_ELTS];
-        compute_t y[LDGS * NUM_ELTS];
-        compute_t mdy_local = 0.f;
-        compute_t mdyy_local = 0.f;
-        // If dz is not loaded, then dy should be 0 and we don't care about the value of y.
-        if (load_dz) {
-            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
-            index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
-            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
-            #pragma unroll
-            for( int it = 0; it < LDGS; it++ ) {
-                if (Is_even_cols || (it < num_valid_ldgs)) {
-                    Rvec x;
-                    Ovec dz;
-                    dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
-                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
-                    x.load_from(params.x, idx_x);
-                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
-                    idx_x += Ktraits::VEC_COLS_PER_LDG;
-                    idx_z += Ktraits::VEC_COLS_PER_LDG;
-                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
-                    #pragma unroll
-                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                        compute_t x_tmp = x.data.elt[jt];
-                        compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
-                        compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
-                        compute_t dz_tmp = dz.data.elt[jt];
-
-                        mdy_local += dy_tmp;
-                        mdyy_local += dy_tmp * y_tmp;
-
-                        dy[it * NUM_ELTS + jt] = dy_tmp;
-                        y[it * NUM_ELTS + jt] = y_tmp;
-
-                        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
-                        dz_sum[it].data.elt[jt] += dz_tmp;
-                    }
-                }
-            }
-        } else {
-            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
-            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
-            #pragma unroll
-            for( int it = 0; it < LDGS; it++ ) {
-                if (Is_even_cols || (it < num_valid_ldgs)) {
-                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
-                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
-                    idx_x += Ktraits::VEC_COLS_PER_LDG;
-                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
-                }
-            }
-        }
-
-        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
-        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
-        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
-
-        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
-        index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
-        #pragma unroll
-        for( int it = 0; it < LDGS; it++ ) {
-            if (Is_even_cols || (it < num_valid_ldgs)) {
-                Ivec dx0;
-                Rvec dresidual;
-                Ivec x0;
-                if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-                #pragma unroll
-                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                    compute_t dx_tmp_res;
-                    if (load_dz) {
-                        compute_t dy_tmp = dy[it * NUM_ELTS + jt];
-                        compute_t y_tmp = y[it * NUM_ELTS + jt];
-                        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
-                        dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
-                    } else {
-                        dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
-                    }
-                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
-                    if (save_dx0) {
-                        compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
-                        if (Is_dropout) {
-                            dx0_tmp_res *= params.dropout_scale;
-                            if (Has_colscale) {
-                                dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
-                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
-                            } else {
-                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
-                            }
-                        } else {
-                            if (Has_colscale) {
-                                dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
-                                dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
-                            } else {
-                                dx0.data.elt[jt] = dx0_tmp_res;
-                            }
-                        }
-                    }
-                }
-                if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
-                if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
-                idx_x += Ktraits::VEC_COLS_PER_LDG;
-                idx_x0 += Ktraits::VEC_COLS_PER_LDG;
-            }
-        }
-
-    }  // end: grid stride loop
-
-    if( WARPS_M == 1 ) {
-        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
-        #pragma unroll
-        for( int it = 0; it < LDGS; it++ ) {
-            if (Is_even_cols || (it < num_valid_ldgs)) {
-                dz_sum[it].store_to(params.dbeta_part, idx);
-                dzy_sum[it].store_to(params.dgamma_part, idx);
-                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
-                idx += Ktraits::VEC_COLS_PER_LDG;
-            }
-        }
-    } else {
-        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
-        // Finalize reduction of part dgamma and dbeta for this CTA
-        // by reducing over the rows held across the WARPS_M warps
-
-        // Assumption: blockSize divides hidden size.
-        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
-        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
-
-        idx = warp_m * Ktraits::VEC_COLS + tid_r;
-        #pragma unroll
-        for( int it = 0; it < LDGS; it++ ) {
-            dz_sum[it].store_to(smem_wgrad, idx);
-            idx += THREADS_PER_ROW;
-        }
-        __syncthreads();
-        compute_t cta_dz_sum[NUM_RES];
-        memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
-        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
-            for( int jt = 0; jt < NUM_RES; jt++ ) {
-                cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
-            }
-        }
-        __syncthreads();
-
-        idx = warp_m * Ktraits::VEC_COLS + tid_r;
-        #pragma unroll
-        for( int it = 0; it < LDGS; it++ ) {
-            dzy_sum[it].store_to(smem_wgrad, idx);
-            idx += THREADS_PER_ROW;
-        }
-        __syncthreads();
-        compute_t cta_dzy_sum[NUM_RES];
-        memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
-        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
-            for( int jt = 0; jt < NUM_RES; jt++ ) {
-                cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
-            }
-        }
-
-        compute_t cta_dcolscale_sum[NUM_RES];
-        if (Has_colscale) {
-            __syncthreads();
-            idx = warp_m * Ktraits::VEC_COLS + tid_r;
-            #pragma unroll
-            for( int it = 0; it < LDGS; it++ ) {
-                dcolscale_sum[it].store_to(smem_wgrad, idx);
-                idx += THREADS_PER_ROW;
-            }
-            __syncthreads();
-            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
-            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
-                for( int jt = 0; jt < NUM_RES; jt++ ) {
-                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
-                }
-            }
-        }
-
-        const index_t num_valid_writes
-            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
-        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
-        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
-        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
-        for( int jt = 0; jt < NUM_RES; jt++ ) {
-            if (Is_even_cols || (jt < num_valid_writes)) {
-                *dgamma_part = cta_dzy_sum[jt];
-                dgamma_part += Ktraits::THREADS_PER_CTA;
-                *dbeta_part = cta_dz_sum[jt];
-                dbeta_part += Ktraits::THREADS_PER_CTA;
-                if (Has_colscale) {
-                    *dcolscale_part = cta_dcolscale_sum[jt];
-                    dcolscale_part += Ktraits::THREADS_PER_CTA;
-                }
-            }
-        }
-
-    }
-}
-
-template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
-__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
-void ln_bwd_finalize_kernel(BwdParams params)
-{
-
-    using compute_t = typename Kernel_traits::compute_t;
-    using weight_t = typename Kernel_traits::weight_t;
-    using index_t = typename Kernel_traits::index_t;
-    using Reducer = typename Kernel_traits::Reducer;
-    using reduce_t = typename Reducer::Type;
-
-    Sum<reduce_t> sum;
-    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
-    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
-
-    __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
-
-    constexpr uint32_t bidm = 0;
-
-    const uint32_t bidn = blockIdx.x;
-    const uint32_t tidx = threadIdx.x;
-    const uint32_t warp = tidx / THREADS_PER_WARP;
-    const uint32_t lane = tidx % THREADS_PER_WARP;
-
-    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
-
-    const uint32_t c = bidn * THREADS_PER_WARP + lane;
-    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
-    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
-    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
-        // Each thread sums over NUM_ELT columns.
-        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
-        memset(&dgamma_local, 0, sizeof(dgamma_local));
-        memset(&dbeta_local, 0, sizeof(dbeta_local));
-        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
-        if (Is_even_cols || col < params.cols) {
-            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
-                index_t idx = row * params.cols + col;
-
-                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
-                dbeta_part.load_from(params.dbeta_part, idx);
-                dgamma_part.load_from(params.dgamma_part, idx);
-                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
-                #pragma unroll
-                for( int it = 0; it < NUM_ELT; it++ ) {
-                    dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
-                    dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
-                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
-                }
-            }
-        }
-        void * smem_gamma = smem_;
-        void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-
-        const int write_row = warp;
-        const int write_col = lane ^ write_row;
-        const int write_idx = write_row * THREADS_PER_WARP + write_col;
-
-        dgamma_local.store_to(smem_gamma, write_idx);
-        dbeta_local.store_to(smem_beta, write_idx);
-        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
-
-        __syncthreads();
-
-        // It would probably be safe to reuse the first row of smem_beta and smem_gamma
-        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
-        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
-
-        // More than one iter iff ROWS_PER_CTA < 32.
-        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
-            const int read_row = lane;
-            const int read_col = w ^ read_row;
-            const int read_idx = read_row * THREADS_PER_WARP + read_col;
-
-            memset(&dbeta_local, 0, sizeof(dbeta_local));
-            memset(&dgamma_local, 0, sizeof(dgamma_local));
-            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
-
-            // Load beta and gamma transposed
-            if(read_row < Kernel_traits::ROWS_PER_CTA){
-                dbeta_local.load_from(smem_beta, read_idx)
;
-                dgamma_local.load_from(smem_gamma, read_idx);
-                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
-            }
-
-            // Call reducer on the loaded value(s) and convert.
-            #pragma unroll
-            for( int it = 0; it < NUM_ELT; it++ ) {
-                compute_t b_i = dbeta_local.data.elt[it];
-                compute_t g_i = dgamma_local.data.elt[it];
-                b_i = reducer.allreduce(b_i, sum);
-                g_i = reducer.allreduce(g_i, sum);
-
-                dgamma_local.data.elt[it] = g_i;
-                dbeta_local.data.elt[it] = b_i;
-                if (Has_colscale) {
-                    compute_t cs_i = dcolscale_local.data.elt[it];
-                    cs_i = reducer.allreduce(cs_i, sum);
-                    dcolscale_local.data.elt[it] = cs_i;
-                }
-            }
-
-            // Leader stores the result at the current column.
-            if(lane == 0){
-                dgamma_local.store_to(smem_gamma_out, w);
-                dbeta_local.store_to(smem_beta_out, w);
-                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
-            }
-
-        }
-
-        // All writes done.
-        __syncthreads();
-
-        // Pack and store: 2-wide stores with half the threads.
-        if (Is_even_cols || col_out * 2 < params.cols) {
-            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
-
-                using src_t = typename TypeToVec2<compute_t>::Type;
-                using dst_t = typename TypeToVec2<weight_t>::Type;
-                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
-                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
-
-                dgamma_vec2.load_from(smem_gamma_out, lane);
-                dbeta_vec2.load_from(smem_beta_out, lane);
-                if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
-                #pragma unroll
-                for( int it = 0; it < NUM_ELT; it++ ) {
-                    dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
-                    dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
-                    if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
-                }
-                dgamma_out2.store_to(params.dgamma, col_out);
-                dbeta_out2.store_to(params.dbeta, col_out);
-                if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
-            }
-        }
-    }
-}
-}  // namespace layer_norm
-
-using namespace layer_norm;
-
-template<
-    typename weight_t,
-    typename input_t,
-    typename residual_t,
-    typename output_t,
-    typename compute_t,
-    typename index_t,
-    int HIDDEN_SIZE,
-    int CTAS_PER_ROW,
-    int WARPS_M,
-    int WARPS_N,
-    int BYTES_PER_LDG_MAIN,
-    int BYTES_PER_LDG_FINAL
->
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
-
-    using Kernel_traits = Kernel_traits<weight_t,
-                                        input_t,
-                                        residual_t,
-                                        output_t,
-                                        compute_t,
-                                        index_t,
-                                        HIDDEN_SIZE,
-                                        CTAS_PER_ROW,
-                                        WARPS_M,
-                                        WARPS_N,
-                                        BYTES_PER_LDG_MAIN
-                                        >;
-    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
-    bool has_colscale = launch_params.params.colscale != nullptr;
-    bool has_subset = launch_params.params.x0_subset != nullptr;
-    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
-    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
-            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
-                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
-                    if( configure_params ) {
-                        int ctas_per_sm;
-                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
-                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                        launch_params.barrier_size = 0;
-                        launch_params.workspace_bytes = 0;
-                        if(Kernel_traits::CTAS_PER_ROW > 1) {
-                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                * Kernel_traits::WARPS_M
-                                * Kernel_traits::CTAS_PER_ROW
-                                * sizeof(typename Kernel_traits::reduce_t)
-                                * 2;
-                        }
-                        return;
-                    }
-
-                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
-                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
-                    }
-                    auto stream = launch_params.stream;
-                    auto ctas_per_col = launch_params.params.ctas_per_col;
-
-                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
-                    } else {
-                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                        dim3 block(Kernel_traits::THREADS_PER_CTA);
-                        void *params_ = (void *)&launch_params.params;
-                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
-                    }
-
-                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
-                                                                               weight_t,
-                                                                               input_t,
-                                                                               residual_t,
-                                                                               output_t,
-                                                                               compute_t,
-                                                                               index_t,
-                                                                               HasColscaleConst,
-                                                                               32 * 32,  // THREADS_PER_CTA
-                                                                               BYTES_PER_LDG_FINAL>;
-
-                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
-                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
-                });
-            });
-        });
-    });
-}
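
Editor's note: to make the vectorized kernel above easier to check, here is a scalar reference for the per-row input gradient it computes (dropout, rowscale, colscale, and subset handling omitted). This is an editorial sketch, not part of the deleted source:

#include <cstddef>
#include <vector>

// Scalar reference for the per-row input gradient computed by ln_bwd_kernel:
//   y  = rsigma * (x - mu)                              (mu dropped for RMSNorm)
//   dy = gamma * dz
//   dx = rsigma * (dy - (mean(dy*y) * y + mean(dy)))    (mean(dy) dropped for RMSNorm)
std::vector<float> ln_bwd_row_reference(const std::vector<float>& x,
                                        const std::vector<float>& dz,
                                        const std::vector<float>& gamma,
                                        float mu, float rsigma, bool is_rms_norm) {
    const std::size_t cols = x.size();
    std::vector<float> y(cols), dy(cols), dx(cols);
    float mdy = 0.f, mdyy = 0.f;
    for (std::size_t j = 0; j < cols; ++j) {
        y[j] = rsigma * (x[j] - (is_rms_norm ? 0.f : mu));
        dy[j] = gamma[j] * dz[j];
        mdy += dy[j];
        mdyy += dy[j] * y[j];
    }
    mdy /= float(cols);    // the two quantities produced by the CTA-wide allreduce
    mdyy /= float(cols);
    for (std::size_t j = 0; j < cols; ++j) {
        dx[j] = rsigma * (dy[j] - (mdyy * y[j] + (is_rms_norm ? 0.f : mdy)));
    }
    return dx;
}

The per-column weight gradients accumulated alongside are dbeta_j = sum over rows of dz and dgamma_j = sum over rows of dz * y, which is exactly what dz_sum and dzy_sum stage into dbeta_part/dgamma_part for the finalize kernel to reduce.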
 
layer_norm/ln_fwd_1024.cu DELETED
@@ -1,15 +0,0 @@
-#include "ln_fwd_kernels.cuh"
-
-// Create forward launch function and register. Macro signature:
-// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
-
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
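
Editor's note: for completeness, the forward math these launchers instantiate, as a scalar reference (an editorial sketch under the usual LayerNorm/RMSNorm definitions, not part of the deleted source):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one row of the forward pass: compute mu and rsigma in
// fp32 (the CTYPE above), then z = gamma * rsigma * (x - mu) + beta.
// For RMSNorm, mu is dropped and the second moment is taken around zero.
std::vector<float> ln_fwd_row_reference(const std::vector<float>& x,
                                        const std::vector<float>& gamma,
                                        const std::vector<float>& beta,
                                        float epsilon, bool is_rms_norm) {
    const std::size_t cols = x.size();
    float mu = 0.f;
    if (!is_rms_norm) {
        for (float v : x) mu += v;
        mu /= float(cols);
    }
    float var = 0.f;
    for (float v : x) var += (v - mu) * (v - mu);
    var /= float(cols);
    const float rsigma = 1.f / std::sqrt(var + epsilon);  // the saved rsigma, reused in backward
    std::vector<float> z(cols);
    for (std::size_t j = 0; j < cols; ++j) {
        z[j] = gamma[j] * rsigma * (x[j] - mu) + beta[j];
    }
    return z;
}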