medmekk HF Staff commited on
Commit
f622ea1
·
1 Parent(s): 4e9c226

add 9.0 build

Browse files
Files changed (36) hide show
  1. CMakeLists.txt +213 -0
  2. build.toml +4 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/layer_norm/__init__.py +26 -0
  4. build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  5. build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py +9 -0
  6. build/torch27-cxx11-cu118-x86_64-linux/layer_norm/layers.py +49 -0
  7. build/torch27-cxx11-cu126-x86_64-linux/layer_norm/__init__.py +26 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  9. build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +9 -0
  10. build/torch27-cxx11-cu126-x86_64-linux/layer_norm/layers.py +49 -0
  11. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/__init__.py +26 -0
  12. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  13. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +9 -0
  14. build/torch27-cxx11-cu128-x86_64-linux/layer_norm/layers.py +49 -0
  15. build/torch28-cxx11-cu126-x86_64-linux/layer_norm/__init__.py +26 -0
  16. build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  17. build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py +9 -0
  18. build/torch28-cxx11-cu126-x86_64-linux/layer_norm/layers.py +49 -0
  19. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/__init__.py +26 -0
  20. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  21. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py +9 -0
  22. build/torch28-cxx11-cu128-x86_64-linux/layer_norm/layers.py +49 -0
  23. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/__init__.py +26 -0
  24. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so +3 -0
  25. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py +9 -0
  26. build/torch28-cxx11-cu129-x86_64-linux/layer_norm/layers.py +49 -0
  27. cmake/hipify.py +76 -0
  28. cmake/utils.cmake +545 -0
  29. flake.lock +168 -0
  30. pyproject.toml +10 -0
  31. setup.py +138 -0
  32. torch-ext/layer_norm/_layer_norm_711aa42_dirty.abi3.so +3 -0
  33. torch-ext/layer_norm/_ops.py +9 -0
  34. torch-ext/registration.h +30 -0
  35. torch-ext/torch_binding.cpp +146 -9
  36. torch-ext/torch_binding.h +66 -4
CMakeLists.txt ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.26)
2
+ project(layer_norm LANGUAGES CXX)
3
+
4
+ set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")
5
+
6
+ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
7
+
8
+ include(FetchContent)
9
+ file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
10
+ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
11
+
12
+ set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
13
+
14
+ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
15
+
16
+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
17
+
18
+ if(DEFINED Python_EXECUTABLE)
19
+ # Allow passing through the interpreter (e.g. from setup.py).
20
+ find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
21
+ if (NOT Python_FOUND)
22
+ message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
23
+ endif()
24
+ else()
25
+ find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
26
+ endif()
27
+
28
+ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
29
+
30
+ find_package(Torch REQUIRED)
31
+
32
+ if (NOT TARGET_DEVICE STREQUAL "cuda" AND
33
+ NOT TARGET_DEVICE STREQUAL "rocm")
34
+ return()
35
+ endif()
36
+
37
+ if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
38
+ CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
39
+ set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
40
+ else()
41
+ set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
42
+ endif()
43
+
44
+ if (NOT HIP_FOUND AND CUDA_FOUND)
45
+ set(GPU_LANG "CUDA")
46
+
47
+
48
+
49
+ elseif(HIP_FOUND)
50
+ set(GPU_LANG "HIP")
51
+
52
+ # Importing torch recognizes and sets up some HIP/ROCm configuration but does
53
+ # not let cmake recognize .hip files. In order to get cmake to understand the
54
+ # .hip extension automatically, HIP must be enabled explicitly.
55
+ enable_language(HIP)
56
+ else()
57
+ message(FATAL_ERROR "Can't find CUDA or HIP installation.")
58
+ endif()
59
+
60
+ if(GPU_LANG STREQUAL "CUDA")
61
+ clear_cuda_arches(CUDA_ARCH_FLAGS)
62
+ extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
63
+ message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
64
+ # Filter the target architectures by the supported supported archs
65
+ # since for some files we will build for all CUDA_ARCHS.
66
+ cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
67
+ message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
68
+
69
+ if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
70
+ list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
71
+ endif()
72
+
73
+ add_compile_definitions(CUDA_KERNEL)
74
+ elseif(GPU_LANG STREQUAL "HIP")
75
+ set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
76
+ # TODO: remove this once we can set specific archs per source file set.
77
+ override_gpu_arches(GPU_ARCHES
78
+ ${GPU_LANG}
79
+ "${${GPU_LANG}_SUPPORTED_ARCHS}")
80
+
81
+ add_compile_definitions(ROCM_KERNEL)
82
+ else()
83
+ override_gpu_arches(GPU_ARCHES
84
+ ${GPU_LANG}
85
+ "${${GPU_LANG}_SUPPORTED_ARCHS}")
86
+ endif()
87
+
88
+ get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
89
+ list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
90
+
91
+ set(TORCH_layer_norm_SRC
92
+ torch-ext/torch_binding.cpp torch-ext/torch_binding.h
93
+ )
94
+
95
+
96
+ list(APPEND SRC "${TORCH_layer_norm_SRC}")
97
+
98
+
99
+ set(layer_norm_SRC
100
+ "layer_norm/ln.h"
101
+ "layer_norm/ln_api.cpp"
102
+ "layer_norm/ln_bwd_1024.cu"
103
+ "layer_norm/ln_bwd_1280.cu"
104
+ "layer_norm/ln_bwd_1536.cu"
105
+ "layer_norm/ln_bwd_2048.cu"
106
+ "layer_norm/ln_bwd_256.cu"
107
+ "layer_norm/ln_bwd_2560.cu"
108
+ "layer_norm/ln_bwd_3072.cu"
109
+ "layer_norm/ln_bwd_4096.cu"
110
+ "layer_norm/ln_bwd_512.cu"
111
+ "layer_norm/ln_bwd_5120.cu"
112
+ "layer_norm/ln_bwd_6144.cu"
113
+ "layer_norm/ln_bwd_7168.cu"
114
+ "layer_norm/ln_bwd_768.cu"
115
+ "layer_norm/ln_bwd_8192.cu"
116
+ "layer_norm/ln_bwd_kernels.cuh"
117
+ "layer_norm/ln_fwd_1024.cu"
118
+ "layer_norm/ln_fwd_1280.cu"
119
+ "layer_norm/ln_fwd_1536.cu"
120
+ "layer_norm/ln_fwd_2048.cu"
121
+ "layer_norm/ln_fwd_256.cu"
122
+ "layer_norm/ln_fwd_2560.cu"
123
+ "layer_norm/ln_fwd_3072.cu"
124
+ "layer_norm/ln_fwd_4096.cu"
125
+ "layer_norm/ln_fwd_512.cu"
126
+ "layer_norm/ln_fwd_5120.cu"
127
+ "layer_norm/ln_fwd_6144.cu"
128
+ "layer_norm/ln_fwd_7168.cu"
129
+ "layer_norm/ln_fwd_768.cu"
130
+ "layer_norm/ln_fwd_8192.cu"
131
+ "layer_norm/ln_fwd_kernels.cuh"
132
+ "layer_norm/ln_kernel_traits.h"
133
+ "layer_norm/ln_parallel_bwd_1024.cu"
134
+ "layer_norm/ln_parallel_bwd_1280.cu"
135
+ "layer_norm/ln_parallel_bwd_1536.cu"
136
+ "layer_norm/ln_parallel_bwd_2048.cu"
137
+ "layer_norm/ln_parallel_bwd_256.cu"
138
+ "layer_norm/ln_parallel_bwd_2560.cu"
139
+ "layer_norm/ln_parallel_bwd_3072.cu"
140
+ "layer_norm/ln_parallel_bwd_4096.cu"
141
+ "layer_norm/ln_parallel_bwd_512.cu"
142
+ "layer_norm/ln_parallel_bwd_5120.cu"
143
+ "layer_norm/ln_parallel_bwd_6144.cu"
144
+ "layer_norm/ln_parallel_bwd_7168.cu"
145
+ "layer_norm/ln_parallel_bwd_768.cu"
146
+ "layer_norm/ln_parallel_bwd_8192.cu"
147
+ "layer_norm/ln_parallel_fwd_1024.cu"
148
+ "layer_norm/ln_parallel_fwd_1280.cu"
149
+ "layer_norm/ln_parallel_fwd_1536.cu"
150
+ "layer_norm/ln_parallel_fwd_2048.cu"
151
+ "layer_norm/ln_parallel_fwd_256.cu"
152
+ "layer_norm/ln_parallel_fwd_2560.cu"
153
+ "layer_norm/ln_parallel_fwd_3072.cu"
154
+ "layer_norm/ln_parallel_fwd_4096.cu"
155
+ "layer_norm/ln_parallel_fwd_512.cu"
156
+ "layer_norm/ln_parallel_fwd_5120.cu"
157
+ "layer_norm/ln_parallel_fwd_6144.cu"
158
+ "layer_norm/ln_parallel_fwd_7168.cu"
159
+ "layer_norm/ln_parallel_fwd_768.cu"
160
+ "layer_norm/ln_parallel_fwd_8192.cu"
161
+ "layer_norm/ln_parallel_residual_bwd_kernels.cuh"
162
+ "layer_norm/ln_parallel_residual_fwd_kernels.cuh"
163
+ "layer_norm/ln_utils.cuh"
164
+ "layer_norm/static_switch.h"
165
+ )
166
+
167
+ # TODO: check if CLion support this:
168
+ # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
169
+ set_source_files_properties(
170
+ ${layer_norm_SRC}
171
+ PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
172
+
173
+ if(GPU_LANG STREQUAL "CUDA")
174
+ cuda_archs_loose_intersection(layer_norm_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}")
175
+ message(STATUS "Capabilities for kernel layer_norm: ${layer_norm_ARCHS}")
176
+ set_gencode_flags_for_srcs(SRCS "${layer_norm_SRC}" CUDA_ARCHS "${layer_norm_ARCHS}")
177
+
178
+
179
+ foreach(_KERNEL_SRC ${layer_norm_SRC})
180
+ if(_KERNEL_SRC MATCHES ".*\\.cu$")
181
+ set_property(
182
+ SOURCE ${_KERNEL_SRC}
183
+ APPEND PROPERTY
184
+ COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;-U__CUDA_NO_BFLOAT16_OPERATORS__;-U__CUDA_NO_BFLOAT16_CONVERSIONS__;-U__CUDA_NO_BFLOAT162_OPERATORS__;-U__CUDA_NO_BFLOAT162_CONVERSIONS__;--expt-relaxed-constexpr;--expt-extended-lambda;--use_fast_math>"
185
+ )
186
+ endif()
187
+ endforeach()
188
+
189
+ foreach(_KERNEL_SRC ${layer_norm_SRC})
190
+ set_property(
191
+ SOURCE ${_KERNEL_SRC}
192
+ APPEND PROPERTY
193
+ COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-DFLASHATTENTION_DISABLE_PYBIND>"
194
+ )
195
+ endforeach()
196
+
197
+ list(APPEND SRC "${layer_norm_SRC}")
198
+ endif()
199
+
200
+
201
+ define_gpu_extension_target(
202
+ _layer_norm_711aa42_dirty
203
+ DESTINATION _layer_norm_711aa42_dirty
204
+ LANGUAGE ${GPU_LANG}
205
+ SOURCES ${SRC}
206
+ COMPILE_FLAGS ${GPU_FLAGS}
207
+ ARCHITECTURES ${GPU_ARCHES}
208
+ #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
209
+ USE_SABI 3
210
+ WITH_SOABI)
211
+
212
+ target_link_options(_layer_norm_711aa42_dirty PRIVATE -static-libstdc++)
213
+
build.toml CHANGED
@@ -11,6 +11,9 @@ src = [
11
  [kernel.layer_norm]
12
  depends = ["torch"]
13
  backend = "cuda"
 
 
 
14
  include = ["."]
15
  src = [
16
  "layer_norm/ln.h",
@@ -79,7 +82,7 @@ src = [
79
  "layer_norm/ln_utils.cuh",
80
  "layer_norm/static_switch.h"
81
  ]
82
- cxx-flags = ["-DFLASHATTENTION_DISABLE_PYBIND"]
83
  cuda-flags = [
84
  "-O3",
85
  "-U__CUDA_NO_HALF_OPERATORS__",
 
11
  [kernel.layer_norm]
12
  depends = ["torch"]
13
  backend = "cuda"
14
+ cuda-capabilities = [
15
+ "9.0"
16
+ ]
17
  include = ["."]
18
  src = [
19
  "layer_norm/ln.h",
 
82
  "layer_norm/ln_utils.cuh",
83
  "layer_norm/static_switch.h"
84
  ]
85
+ cxx-flags = ["-DFLASHATTENTION_DISABLE_PYBIND", "-mcmodel=large"]
86
  cuda-flags = [
87
  "-O3",
88
  "-U__CUDA_NO_HALF_OPERATORS__",
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34e4a57b8d721c4dafb541a81e161435d25198632e3e4c8e2bc66c17eccc236f
3
+ size 248321384
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f541911e5471865e47faf1641da36bcee3b206aa4993949a3cac966c3b936d27
3
+ size 247115320
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7db683e74d55a1a71dc520a504521af3f08fb07724675d2097ce3d4ab3481e3d
3
+ size 246751936
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b28a4d7885c08614b479490306561990c4cf6e5958dedd5ce59c2ee10bd0f0a
3
+ size 247115408
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69c897ea7e96a6988909ac3878f74baa2b598b0301a2ee3227f9f1c9804fb64d
3
+ size 246756512
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+ from . import layers
7
+
8
+ def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
9
+ return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
10
+
11
+ def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
12
+ return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
13
+
14
+ def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
15
+ return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
16
+
17
+ def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
18
+ return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
19
+
20
+ __all__ = [
21
+ "layers",
22
+ "dropout_add_ln_fwd",
23
+ "dropout_add_ln_bwd",
24
+ "dropout_add_ln_parallel_residual_fwd",
25
+ "dropout_add_ln_parallel_residual_bwd",
26
+ ]
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_layer_norm_4e9c226_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:594fd2ab65b273a4fee370bab7e03cb79cbc9c320eb37364466940a60ef154fa
3
+ size 248443760
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_4e9c226_dirty
3
+ ops = torch.ops._layer_norm_4e9c226_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_4e9c226_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/layer_norm/layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ class LayerNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return ops.dropout_add_ln_fwd(
13
+ hidden_states,
14
+ gamma = self.weight,
15
+ beta = None,
16
+ rowscale = None,
17
+ colscale = None,
18
+ x0_subset = None,
19
+ z_subset = None,
20
+ dropout_p = 0,
21
+ epsilon = self.variance_epsilon,
22
+ rowscale_const = 1.0,
23
+ z_numrows = hidden_states.shape[1],
24
+ gen = None,
25
+ residual_in_fp32 = False,
26
+ is_rms_norm = False,
27
+ )
28
+
29
+ class LlamaRMSNorm(nn.Module):
30
+ weight: torch.Tensor
31
+ variance_epsilon: float
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ return ops.dropout_add_ln_fwd(
35
+ hidden_states,
36
+ gamma = self.weight,
37
+ beta = None,
38
+ rowscale = None,
39
+ colscale = None,
40
+ x0_subset = None,
41
+ z_subset = None,
42
+ dropout_p = 0,
43
+ epsilon = self.variance_epsilon,
44
+ rowscale_const = 1.0,
45
+ z_numrows = hidden_states.shape[1],
46
+ gen = None,
47
+ residual_in_fp32 = False,
48
+ is_rms_norm = True,
49
+ )
cmake/hipify.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # From vLLM: https://github.com/vllm-project/vllm/blob/main/cmake/hipify.py
5
+
6
+ #
7
+ # A command line tool for running pytorch's hipify preprocessor on CUDA
8
+ # source files.
9
+ #
10
+ # See https://github.com/ROCm/hipify_torch
11
+ # and <torch install dir>/utils/hipify/hipify_python.py
12
+ #
13
+
14
+ import argparse
15
+ import os
16
+ import shutil
17
+
18
+ from torch.utils.hipify.hipify_python import hipify
19
+
20
+ if __name__ == '__main__':
21
+ parser = argparse.ArgumentParser()
22
+
23
+ # Project directory where all the source + include files live.
24
+ parser.add_argument(
25
+ "-p",
26
+ "--project_dir",
27
+ help="The project directory.",
28
+ )
29
+
30
+ # Directory where hipified files are written.
31
+ parser.add_argument(
32
+ "-o",
33
+ "--output_dir",
34
+ help="The output directory.",
35
+ )
36
+
37
+ # Source files to convert.
38
+ parser.add_argument("sources",
39
+ help="Source files to hipify.",
40
+ nargs="*",
41
+ default=[])
42
+
43
+ args = parser.parse_args()
44
+
45
+ # Limit include scope to project_dir only
46
+ includes = [os.path.join(args.project_dir, '*')]
47
+
48
+ # Get absolute path for all source files.
49
+ extra_files = [os.path.abspath(s) for s in args.sources]
50
+
51
+ # Copy sources from project directory to output directory.
52
+ # The directory might already exist to hold object files so we ignore that.
53
+ shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
54
+
55
+ hipify_result = hipify(project_directory=args.project_dir,
56
+ output_directory=args.output_dir,
57
+ header_include_dirs=[],
58
+ includes=includes,
59
+ extra_files=extra_files,
60
+ show_detailed=True,
61
+ is_pytorch_extension=True,
62
+ hipify_extra_files_only=True)
63
+
64
+ hipified_sources = []
65
+ for source in args.sources:
66
+ s_abs = os.path.abspath(source)
67
+ hipified_s_abs = (hipify_result[s_abs].hipified_path if
68
+ (s_abs in hipify_result
69
+ and hipify_result[s_abs].hipified_path is not None)
70
+ else s_abs)
71
+ hipified_sources.append(hipified_s_abs)
72
+
73
+ assert (len(hipified_sources) == len(args.sources))
74
+
75
+ # Print hipified source files.
76
+ print("\n".join(hipified_sources))
cmake/utils.cmake ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vendored from vLLM:
2
+ #
3
+ # https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
4
+ #
5
+ # Attempt to find the python package that uses the same python executable as
6
+ # `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
7
+ #
8
+ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
9
+ file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
10
+ set(Python_EXECUTABLE ${EXECUTABLE})
11
+ find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
12
+ if (NOT Python_FOUND)
13
+ message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
14
+ endif()
15
+ set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
16
+ set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
17
+ if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
18
+ message(FATAL_ERROR
19
+ "Python version (${_VER}) is not one of the supported versions: "
20
+ "${_SUPPORTED_VERSIONS_LIST}.")
21
+ endif()
22
+ message(STATUS "Found python matching: ${EXECUTABLE}.")
23
+ endmacro()
24
+
25
+ #
26
+ # Run `EXPR` in python. The standard output of python is stored in `OUT` and
27
+ # has trailing whitespace stripped. If an error is encountered when running
28
+ # python, a fatal message `ERR_MSG` is issued.
29
+ #
30
+ function (run_python OUT EXPR ERR_MSG)
31
+ execute_process(
32
+ COMMAND
33
+ "${Python_EXECUTABLE}" "-c" "${EXPR}"
34
+ OUTPUT_VARIABLE PYTHON_OUT
35
+ RESULT_VARIABLE PYTHON_ERROR_CODE
36
+ ERROR_VARIABLE PYTHON_STDERR
37
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
38
+
39
+ if(NOT PYTHON_ERROR_CODE EQUAL 0)
40
+ message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
41
+ endif()
42
+ set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
43
+ endfunction()
44
+
45
+ # Run `EXPR` in python after importing `PKG`. Use the result of this to extend
46
+ # `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
47
+ macro (append_cmake_prefix_path PKG EXPR)
48
+ run_python(_PREFIX_PATH
49
+ "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
50
+ list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
51
+ endmacro()
52
+
53
+ #
54
+ # Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
55
+ # of CUDA source files. The names of the corresponding "hipified" sources are
56
+ # stored in `OUT_SRCS`.
57
+ #
58
+ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
59
+ #
60
+ # Split into C++ and non-C++ (i.e. CUDA) sources.
61
+ #
62
+ set(NODUP_SRCS ${ORIG_SRCS})
63
+ list(REMOVE_DUPLICATES NODUP_SRCS)
64
+ set(SRCS ${NODUP_SRCS})
65
+ set(CXX_SRCS ${NODUP_SRCS})
66
+ list(FILTER SRCS INCLUDE REGEX "\.cu$")
67
+ list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
68
+
69
+ #
70
+ # Generate ROCm/HIP source file names from CUDA file names.
71
+ # Since HIP files are generated code, they will appear in the build area
72
+ # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
73
+ #
74
+ set(HIP_SRCS)
75
+ foreach (SRC ${SRCS})
76
+ get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
77
+ string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
78
+ string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
79
+
80
+ if(include_dirs)
81
+ # Copy over include directories from the original CUDA file.
82
+ set_source_files_properties(
83
+ ${SRC}
84
+ PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
85
+ endif()
86
+
87
+ list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
88
+ endforeach()
89
+
90
+ add_custom_target(
91
+ hipify${NAME}
92
+ COMMAND "${Python_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
93
+ DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
94
+ BYPRODUCTS ${HIP_SRCS}
95
+ COMMENT "Running hipify on ${NAME} extension source files.")
96
+
97
+ # Swap out original extension sources with hipified sources.
98
+ list(APPEND HIP_SRCS ${CXX_SRCS})
99
+ set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
100
+ endfunction()
101
+
102
+ #
103
+ # Get additional GPU compiler flags from torch.
104
+ #
105
+ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
106
+ if (${GPU_LANG} STREQUAL "CUDA")
107
+ #
108
+ # Get common NVCC flags from torch.
109
+ #
110
+ run_python(GPU_FLAGS
111
+ "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
112
+ "Failed to determine torch nvcc compiler flags")
113
+
114
+ if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
115
+ list(APPEND GPU_FLAGS "-DENABLE_FP8")
116
+ list(REMOVE_ITEM GPU_FLAGS
117
+ "-D__CUDA_NO_HALF_OPERATORS__"
118
+ "-D__CUDA_NO_HALF_CONVERSIONS__"
119
+ "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
120
+ "-D__CUDA_NO_HALF2_OPERATORS__")
121
+ endif()
122
+
123
+ elseif(${GPU_LANG} STREQUAL "HIP")
124
+ #
125
+ # Get common HIP/HIPCC flags from torch.
126
+ #
127
+ run_python(GPU_FLAGS
128
+ "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
129
+ "Failed to determine torch nvcc compiler flags")
130
+
131
+ list(APPEND GPU_FLAGS
132
+ "-DUSE_ROCM"
133
+ "-DENABLE_FP8"
134
+ "-U__HIP_NO_HALF_CONVERSIONS__"
135
+ "-U__HIP_NO_HALF_OPERATORS__"
136
+ "-fno-gpu-rdc")
137
+
138
+ endif()
139
+ set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
140
+ endfunction()
141
+
142
+ # Macro for converting a `gencode` version number to a cmake version number.
143
+ macro(string_to_ver OUT_VER IN_STR)
144
+ string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
145
+ endmacro()
146
+
147
+ #
148
+ # Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
149
+ # `CUDA_ARCH_FLAGS`.
150
+ #
151
+ # Example:
152
+ # CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
153
+ # clear_cuda_arches(CUDA_ARCH_FLAGS)
154
+ # CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
155
+ # CMAKE_CUDA_FLAGS="-Wall"
156
+ #
157
+ macro(clear_cuda_arches CUDA_ARCH_FLAGS)
158
+ # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
159
+ string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
160
+ ${CMAKE_CUDA_FLAGS})
161
+
162
+ # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
163
+ # and passed back via the `CUDA_ARCHITECTURES` property.
164
+ string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
165
+ ${CMAKE_CUDA_FLAGS})
166
+ endmacro()
167
+
168
+ #
169
+ # Extract unique CUDA architectures from a list of compute capabilities codes in
170
+ # the form `<major><minor>[<letter>]`, convert them to the form sort
171
+ # `<major>.<minor>`, dedupes them and then sorts them in ascending order and
172
+ # stores them in `OUT_ARCHES`.
173
+ #
174
+ # Example:
175
+ # CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
176
+ # extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
177
+ # OUT_ARCHES="7.5;...;9.0"
178
+ function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
179
+ set(_CUDA_ARCHES)
180
+ foreach(_ARCH ${CUDA_ARCH_FLAGS})
181
+ string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
182
+ if (_COMPUTE)
183
+ set(_COMPUTE ${CMAKE_MATCH_1})
184
+ endif()
185
+
186
+ string_to_ver(_COMPUTE_VER ${_COMPUTE})
187
+ list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
188
+ endforeach()
189
+
190
+ list(REMOVE_DUPLICATES _CUDA_ARCHES)
191
+ list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
192
+ set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
193
+ endfunction()
194
+
195
+ #
196
+ # For a specific file set the `-gencode` flag in compile options conditionally
197
+ # for the CUDA language.
198
+ #
199
+ # Example:
200
+ # set_gencode_flag_for_srcs(
201
+ # SRCS "foo.cu"
202
+ # ARCH "compute_75"
203
+ # CODE "sm_75")
204
+ # adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
205
+ # `foo.cu` (only for the CUDA language).
206
+ #
207
+ macro(set_gencode_flag_for_srcs)
208
+ set(options)
209
+ set(oneValueArgs ARCH CODE)
210
+ set(multiValueArgs SRCS)
211
+ cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
212
+ "${multiValueArgs}" ${ARGN} )
213
+ set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
214
+ set_property(
215
+ SOURCE ${arg_SRCS}
216
+ APPEND PROPERTY
217
+ COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
218
+ )
219
+
220
+ message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
221
+ endmacro(set_gencode_flag_for_srcs)
222
+
223
+ #
224
+ # For a list of source files set the `-gencode` flags in the files specific
225
+ # compile options (specifically for the CUDA language).
226
+ #
227
+ # arguments are:
228
+ # SRCS: list of source files
229
+ # CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
230
+ # BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
231
+ # for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
232
+ # that is larger than BUILD_PTX_FOR_ARCH.
233
+ #
234
+ macro(set_gencode_flags_for_srcs)
235
+ set(options)
236
+ set(oneValueArgs BUILD_PTX_FOR_ARCH)
237
+ set(multiValueArgs SRCS CUDA_ARCHS)
238
+ cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
239
+ "${multiValueArgs}" ${ARGN} )
240
+
241
+ foreach(_ARCH ${arg_CUDA_ARCHS})
242
+ # handle +PTX suffix: generate both sm and ptx codes if requested
243
+ string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
244
+ if(NOT _HAS_PTX EQUAL -1)
245
+ string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
246
+ string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
247
+ set_gencode_flag_for_srcs(
248
+ SRCS ${arg_SRCS}
249
+ ARCH "compute_${_STRIPPED_ARCH}"
250
+ CODE "sm_${_STRIPPED_ARCH}")
251
+ set_gencode_flag_for_srcs(
252
+ SRCS ${arg_SRCS}
253
+ ARCH "compute_${_STRIPPED_ARCH}"
254
+ CODE "compute_${_STRIPPED_ARCH}")
255
+ else()
256
+ string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
257
+ set_gencode_flag_for_srcs(
258
+ SRCS ${arg_SRCS}
259
+ ARCH "compute_${_STRIPPED_ARCH}"
260
+ CODE "sm_${_STRIPPED_ARCH}")
261
+ endif()
262
+ endforeach()
263
+
264
+ if (${arg_BUILD_PTX_FOR_ARCH})
265
+ list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
266
+ list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
267
+ if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
268
+ string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
269
+ set_gencode_flag_for_srcs(
270
+ SRCS ${arg_SRCS}
271
+ ARCH "compute_${_PTX_ARCH}"
272
+ CODE "compute_${_PTX_ARCH}")
273
+ endif()
274
+ endif()
275
+ endmacro()
276
+
277
+ #
278
+ # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
279
+ # `<major>.<minor>[letter]` compute the "loose intersection" with the
280
+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
281
+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
282
+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
283
+ # architecture in `SRC_CUDA_ARCHS`.
284
+ # The loose intersection is defined as:
285
+ # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
286
+ # where `<=` is the version comparison operator.
287
+ # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
288
+ # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
289
+ # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
290
+ # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
291
+ # x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
292
+ # The result is stored in `OUT_CUDA_ARCHS`.
293
+ #
294
+ # Example:
295
+ # SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
296
+ # TGT_CUDA_ARCHS="8.0;8.9;9.0"
297
+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
298
+ # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
299
+ #
300
+ # Example With PTX:
301
+ # SRC_CUDA_ARCHS="8.0+PTX"
302
+ # TGT_CUDA_ARCHS="9.0"
303
+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
304
+ # OUT_CUDA_ARCHS="8.0+PTX"
305
+ #
306
+ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
307
+ set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
308
+ set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
309
+
310
+ # handle +PTX suffix: separate base arch for matching, record PTX requests
311
+ set(_PTX_ARCHS)
312
+ foreach(_arch ${_SRC_CUDA_ARCHS})
313
+ if(_arch MATCHES "\\+PTX$")
314
+ string(REPLACE "+PTX" "" _base "${_arch}")
315
+ list(APPEND _PTX_ARCHS "${_base}")
316
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
317
+ list(APPEND _SRC_CUDA_ARCHS "${_base}")
318
+ endif()
319
+ endforeach()
320
+ list(REMOVE_DUPLICATES _PTX_ARCHS)
321
+ list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
322
+
323
+ # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
324
+ # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
325
+ set(_CUDA_ARCHS)
326
+ foreach(_arch ${_SRC_CUDA_ARCHS})
327
+ if(_arch MATCHES "\\a$")
328
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
329
+ string(REPLACE "a" "" _base "${_arch}")
330
+ if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
331
+ list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
332
+ list(APPEND _CUDA_ARCHS "${_arch}")
333
+ endif()
334
+ endif()
335
+ endforeach()
336
+
337
+ list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
338
+
339
+ # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
340
+ # is less or equal to ARCH (but has the same major version since SASS binary
341
+ # compatibility is only forward compatible within the same major version).
342
+ foreach(_ARCH ${_TGT_CUDA_ARCHS})
343
+ set(_TMP_ARCH)
344
+ # Extract the major version of the target arch
345
+ string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
346
+ foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
347
+ # Extract the major version of the source arch
348
+ string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
349
+ # Check version-less-or-equal, and allow PTX arches to match across majors
350
+ if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
351
+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
352
+ set(_TMP_ARCH "${_SRC_ARCH}")
353
+ endif()
354
+ else()
355
+ # If we hit a version greater than the target, we can break
356
+ break()
357
+ endif()
358
+ endforeach()
359
+
360
+ # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
361
+ if (_TMP_ARCH)
362
+ list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
363
+ endif()
364
+ endforeach()
365
+
366
+ list(REMOVE_DUPLICATES _CUDA_ARCHS)
367
+
368
+ # reapply +PTX suffix to architectures that requested PTX
369
+ set(_FINAL_ARCHS)
370
+ foreach(_arch ${_CUDA_ARCHS})
371
+ if(_arch IN_LIST _PTX_ARCHS)
372
+ list(APPEND _FINAL_ARCHS "${_arch}+PTX")
373
+ else()
374
+ list(APPEND _FINAL_ARCHS "${_arch}")
375
+ endif()
376
+ endforeach()
377
+ set(_CUDA_ARCHS ${_FINAL_ARCHS})
378
+
379
+ set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
380
+ endfunction()
381
+
382
+ #
383
+ # For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
384
+ # `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
385
+ # The loose intersection is defined as:
386
+ # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
387
+ # where `<=` is the version comparison operator.
388
+ # In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
389
+ # in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
390
+ # The result is stored in `OUT_ROCM_ARCHS`.
391
+ #
392
+ # Example:
393
+ # SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
394
+ # TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
395
+ # hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
396
+ # OUT_ROCM_ARCHS="gfx906;gfx908"
397
+ #
398
+ function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
399
+ list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
400
+
401
+ # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
402
+ # and x is a letter. We can sort them by string comparison which works for this format.
403
+ list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
404
+
405
+ set(_ROCM_ARCHS)
406
+
407
+ # Find the intersection of supported architectures
408
+ foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
409
+ if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
410
+ list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
411
+ endif()
412
+ endforeach()
413
+
414
+ list(REMOVE_DUPLICATES _ROCM_ARCHS)
415
+ set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
416
+ endfunction()
417
+
418
+ #
419
+ # Override the GPU architectures detected by cmake/torch and filter them by
420
+ # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
421
+ # `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
422
+ # the architectures on a per file basis.
423
+ #
424
+ # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
425
+ #
426
+ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
427
+ set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
428
+ message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
429
+
430
+ if (${GPU_LANG} STREQUAL "HIP")
431
+ #
432
+ # `GPU_ARCHES` controls the `--offload-arch` flags.
433
+ #
434
+ # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
435
+ # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
436
+ # "rocm_agent_enumerator" in "enable_language(HIP)"
437
+ # (in file Modules/CMakeDetermineHIPCompiler.cmake)
438
+ #
439
+ if(DEFINED ENV{PYTORCH_ROCM_ARCH})
440
+ set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
441
+ else()
442
+ set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
443
+ endif()
444
+ #
445
+ # Find the intersection of the supported + detected architectures to
446
+ # set the module architecture flags.
447
+ #
448
+ set(${GPU_ARCHES})
449
+ foreach (_ARCH ${HIP_ARCHITECTURES})
450
+ if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
451
+ list(APPEND ${GPU_ARCHES} ${_ARCH})
452
+ endif()
453
+ endforeach()
454
+
455
+ if(NOT ${GPU_ARCHES})
456
+ message(FATAL_ERROR
457
+ "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
458
+ " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
459
+ endif()
460
+ endif()
461
+ endmacro()
462
+
463
+ #
464
+ # Define a target named `GPU_MOD_NAME` for a single extension. The
465
+ # arguments are:
466
+ #
467
+ # DESTINATION <dest> - Module destination directory.
468
+ # LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
469
+ # etc.
470
+ # SOURCES <sources> - List of source files relative to CMakeLists.txt
471
+ # directory.
472
+ #
473
+ # Optional arguments:
474
+ #
475
+ # ARCHITECTURES <arches> - A list of target GPU architectures in cmake
476
+ # format.
477
+ # Refer `CMAKE_CUDA_ARCHITECTURES` documentation
478
+ # and `CMAKE_HIP_ARCHITECTURES` for more info.
479
+ # ARCHITECTURES will use cmake's defaults if
480
+ # not provided.
481
+ # COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
482
+ # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
483
+ # LIBRARIES <libraries> - Extra link libraries.
484
+ # WITH_SOABI - Generate library with python SOABI suffix name.
485
+ # USE_SABI <version> - Use python stable api <version>
486
+ #
487
+ # Note: optimization level/debug info is set via cmake build type.
488
+ #
489
+ function (define_gpu_extension_target GPU_MOD_NAME)
490
+ cmake_parse_arguments(PARSE_ARGV 1
491
+ GPU
492
+ "WITH_SOABI"
493
+ "DESTINATION;LANGUAGE;USE_SABI"
494
+ "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
495
+
496
+ # Add hipify preprocessing step when building with HIP/ROCm.
497
+ if (GPU_LANGUAGE STREQUAL "HIP")
498
+ hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
499
+ endif()
500
+
501
+ if (GPU_WITH_SOABI)
502
+ set(GPU_WITH_SOABI WITH_SOABI)
503
+ else()
504
+ set(GPU_WITH_SOABI)
505
+ endif()
506
+
507
+ if (GPU_USE_SABI)
508
+ Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
509
+ else()
510
+ Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
511
+ endif()
512
+
513
+ if (GPU_LANGUAGE STREQUAL "HIP")
514
+ # Make this target dependent on the hipify preprocessor step.
515
+ add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
516
+ endif()
517
+
518
+ if (GPU_ARCHITECTURES)
519
+ set_target_properties(${GPU_MOD_NAME} PROPERTIES
520
+ ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
521
+ endif()
522
+
523
+ set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
524
+
525
+ target_compile_options(${GPU_MOD_NAME} PRIVATE
526
+ $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
527
+
528
+ target_compile_definitions(${GPU_MOD_NAME} PRIVATE
529
+ "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
530
+
531
+ target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
532
+ ${GPU_INCLUDE_DIRECTORIES})
533
+
534
+ target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
535
+
536
+ # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
537
+ # dependencies that are not necessary and may not be installed.
538
+ if (GPU_LANGUAGE STREQUAL "CUDA")
539
+ target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
540
+ else()
541
+ target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
542
+ endif()
543
+
544
+ install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
545
+ endfunction()
flake.lock ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1747046372,
21
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1757675377,
77
+ "narHash": "sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "faf3354403a7381958d08e826c15fe30f6986a4f",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1758103102,
102
+ "narHash": "sha256-z9E9FxuxuxUztG5DbUcOvKBHvd27gBY9617t9x2QE6M=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "94369928dc09ea7753c58495e3e406ac26f6c378",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1755963616,
117
+ "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
118
+ "owner": "nixos",
119
+ "repo": "nixpkgs",
120
+ "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "nixos",
125
+ "ref": "nixos-unstable-small",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
pyproject.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = [
3
+ "cmake>=3.26",
4
+ "ninja",
5
+ "packaging",
6
+ "setuptools>=61",
7
+ "torch",
8
+ "wheel",
9
+ ]
10
+ build-backend = "setuptools.build_meta"
setup.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from shutil import which, move
4
+ import subprocess
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from setuptools import Extension, find_packages, setup
9
+ from setuptools.command.build_ext import build_ext
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def is_sccache_available() -> bool:
15
+ return which("sccache") is not None
16
+
17
+
18
+ def is_ccache_available() -> bool:
19
+ return which("ccache") is not None
20
+
21
+
22
+ def is_ninja_available() -> bool:
23
+ return which("ninja") is not None
24
+
25
+
26
+ class CMakeExtension(Extension):
27
+ def __init__(self, name: str, sourcedir: str = "") -> None:
28
+ super().__init__(name, sources=[], py_limited_api=True)
29
+ self.sourcedir = os.fspath(Path(sourcedir).resolve())
30
+
31
+
32
+ class CMakeBuild(build_ext):
33
+ def build_extension(self, ext: CMakeExtension) -> None:
34
+ ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
35
+ extdir = ext_fullpath.parent.resolve()
36
+
37
+ debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
38
+ cfg = "Debug" if debug else "Release"
39
+
40
+ cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
41
+
42
+ # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
43
+ # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
44
+ # from Python.
45
+ cmake_args = [
46
+ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
47
+ f"-DPython_EXECUTABLE={sys.executable}",
48
+ f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
49
+ ]
50
+ build_args = []
51
+ if "CMAKE_ARGS" in os.environ:
52
+ cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
53
+
54
+ if not cmake_generator or cmake_generator == "Ninja":
55
+ try:
56
+ import ninja
57
+
58
+ ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
59
+ cmake_args += [
60
+ "-GNinja",
61
+ f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
62
+ ]
63
+ except ImportError:
64
+ pass
65
+
66
+ if is_sccache_available():
67
+ cmake_args += [
68
+ "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
69
+ "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
70
+ "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
71
+ "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
72
+ ]
73
+ elif is_ccache_available():
74
+ cmake_args += [
75
+ "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
76
+ "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
77
+ "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
78
+ "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
79
+ ]
80
+
81
+ num_jobs = os.getenv("MAX_JOBS", None)
82
+ if num_jobs is not None:
83
+ num_jobs = int(num_jobs)
84
+ logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
85
+ else:
86
+ try:
87
+ # os.sched_getaffinity() isn't universally available, so fall
88
+ # back to os.cpu_count() if we get an error here.
89
+ num_jobs = len(os.sched_getaffinity(0))
90
+ except AttributeError:
91
+ num_jobs = os.cpu_count()
92
+
93
+ nvcc_threads = os.getenv("NVCC_THREADS", None)
94
+ if nvcc_threads is not None:
95
+ nvcc_threads = int(nvcc_threads)
96
+ logger.info(
97
+ "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
98
+ )
99
+ else:
100
+ nvcc_threads = 1
101
+ num_jobs = max(1, num_jobs // nvcc_threads)
102
+
103
+ build_args += [f"-j{num_jobs}"]
104
+ if sys.platform == "win32":
105
+ build_args += ["--config", cfg]
106
+
107
+ if nvcc_threads:
108
+ cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]
109
+
110
+ build_temp = Path(self.build_temp) / ext.name
111
+ if not build_temp.exists():
112
+ build_temp.mkdir(parents=True)
113
+
114
+ subprocess.run(
115
+ ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
116
+ )
117
+ subprocess.run(
118
+ ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
119
+ )
120
+ if sys.platform == "win32":
121
+ # Move the dylib one folder up for discovery.
122
+ for filename in os.listdir(extdir / cfg):
123
+ move(extdir / cfg / filename, extdir / filename)
124
+
125
+
126
+
127
+ setup(
128
+ name="layer_norm",
129
+ # The version is just a stub, it's not used by the final build artefact.
130
+ version="0.1.0",
131
+ ext_modules=[CMakeExtension("layer_norm._layer_norm_711aa42_dirty")],
132
+ cmdclass={"build_ext": CMakeBuild},
133
+ packages=find_packages(where="torch-ext", include=["layer_norm*"]),
134
+ package_dir={"": "torch-ext"},
135
+ zip_safe=False,
136
+ install_requires=["torch"],
137
+ python_requires=">=3.9",
138
+ )
torch-ext/layer_norm/_layer_norm_711aa42_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c824a0d2b400f4a89ccf293975ccfedc32733174dad4386a402149c440946674
3
+ size 247782208
torch-ext/layer_norm/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _layer_norm_711aa42_dirty
3
+ ops = torch.ops._layer_norm_711aa42_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_layer_norm_711aa42_dirty::{op_name}"
torch-ext/registration.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Registration macros from vLLM:
2
+ // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
3
+
4
+ #pragma once
5
+
6
+ #include <Python.h>
7
+
8
+ #define _CONCAT(A, B) A##B
9
+ #define CONCAT(A, B) _CONCAT(A, B)
10
+
11
+ #define _STRINGIFY(A) #A
12
+ #define STRINGIFY(A) _STRINGIFY(A)
13
+
14
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
15
+ // could be a macro instead of a literal token.
16
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
17
+
18
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
19
+ // could be a macro instead of a literal token.
20
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
21
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
22
+
23
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
24
+ // via python's import statement.
25
+ #define REGISTER_EXTENSION(NAME) \
26
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
27
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
28
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
29
+ return PyModule_Create(&module); \
30
+ }
torch-ext/torch_binding.cpp CHANGED
@@ -3,15 +3,152 @@
3
  #include "registration.h"
4
  #include "torch_binding.h"
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
- ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor beta, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float epsilon, float rowscale_const, int64_t z_numrows, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
8
- ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd);
9
- ops.def("dropout_add_ln_bwd(Tensor dz, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float rowscale_const, int64_t x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor");
10
- ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd);
11
- ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor gamma0, Tensor beta0, Tensor gamma1, Tensor beta1, float dropout_p, float epsilon, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
12
- ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd);
13
- ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor dz1, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor");
14
- ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd);
 
 
 
 
15
  }
16
 
17
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
3
  #include "registration.h"
4
  #include "torch_binding.h"
5
 
6
+ // Helper to turn Tensor? from schema (optional by value) into optional<const Tensor>& args
7
+ template <typename T>
8
+ static c10::optional<const at::Tensor> as_const_opt(const c10::optional<T>& v) {
9
+ if (v.has_value()) return c10::optional<const at::Tensor>(v.value());
10
+ return c10::optional<const at::Tensor>();
11
+ }
12
+
13
+ // Wrappers with dispatcher-friendly types (double scalars, optional Generator)
14
+ // Forward
15
+ static std::vector<at::Tensor> dropout_add_ln_fwd_wrap(
16
+ const at::Tensor& input,
17
+ const at::Tensor& gamma,
18
+ c10::optional<at::Tensor> beta,
19
+ c10::optional<at::Tensor> rowscale,
20
+ c10::optional<at::Tensor> colscale,
21
+ c10::optional<at::Tensor> x0_subset,
22
+ c10::optional<at::Tensor> z_subset,
23
+ double dropout_p,
24
+ double epsilon,
25
+ double rowscale_const,
26
+ int64_t z_numrows,
27
+ c10::optional<at::Generator> gen,
28
+ bool residual_in_fp32,
29
+ bool is_rms_norm) {
30
+
31
+ // residual is not exposed in this schema (None)
32
+ auto residual_c = c10::optional<const at::Tensor>();
33
+ auto beta_c = as_const_opt(beta);
34
+ auto rowscale_c = as_const_opt(rowscale);
35
+ auto colscale_c = as_const_opt(colscale);
36
+ auto x0_subset_c = as_const_opt(x0_subset);
37
+ auto z_subset_c = as_const_opt(z_subset);
38
+
39
+ return dropout_add_ln_fwd(
40
+ input, residual_c, gamma, beta_c, rowscale_c, colscale_c, x0_subset_c, z_subset_c,
41
+ static_cast<float>(dropout_p),
42
+ static_cast<float>(epsilon),
43
+ static_cast<float>(rowscale_const),
44
+ z_numrows, gen, residual_in_fp32, is_rms_norm);
45
+ }
46
+
47
+ // Backward
48
+ static std::vector<at::Tensor> dropout_add_ln_bwd_wrap(
49
+ const at::Tensor& dz,
50
+ c10::optional<at::Tensor> dx,
51
+ const at::Tensor& x,
52
+ c10::optional<at::Tensor> x0,
53
+ c10::optional<at::Tensor> dmask,
54
+ const at::Tensor& mu,
55
+ const at::Tensor& rsigma,
56
+ const at::Tensor& gamma,
57
+ c10::optional<at::Tensor> rowscale,
58
+ c10::optional<at::Tensor> colscale,
59
+ c10::optional<at::Tensor> x0_subset,
60
+ c10::optional<at::Tensor> z_subset,
61
+ double dropout_p,
62
+ double rowscale_const,
63
+ int64_t x0_numrows,
64
+ bool has_residual,
65
+ bool is_rms_norm) {
66
+
67
+ auto dx_c = as_const_opt(dx);
68
+ auto x0_c = as_const_opt(x0);
69
+ auto dmask_c = as_const_opt(dmask);
70
+ auto rowscale_c = as_const_opt(rowscale);
71
+ auto colscale_c = as_const_opt(colscale);
72
+ auto x0_subset_c = as_const_opt(x0_subset);
73
+ auto z_subset_c = as_const_opt(z_subset);
74
+
75
+ return dropout_add_ln_bwd(
76
+ dz, dx_c, x, x0_c, dmask_c, mu, rsigma, gamma,
77
+ rowscale_c, colscale_c, x0_subset_c, z_subset_c,
78
+ static_cast<float>(dropout_p),
79
+ static_cast<float>(rowscale_const),
80
+ x0_numrows, has_residual, is_rms_norm);
81
+ }
82
+
83
+ // Parallel forward
84
+ static std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd_wrap(
85
+ const at::Tensor& input,
86
+ c10::optional<at::Tensor> x1,
87
+ c10::optional<at::Tensor> residual,
88
+ const at::Tensor& gamma0,
89
+ c10::optional<at::Tensor> beta0,
90
+ c10::optional<at::Tensor> gamma1,
91
+ c10::optional<at::Tensor> beta1,
92
+ double dropout_p,
93
+ double epsilon,
94
+ c10::optional<at::Generator> gen,
95
+ bool residual_in_fp32,
96
+ bool is_rms_norm) {
97
+
98
+ auto x1_c = as_const_opt(x1);
99
+ auto residual_c = as_const_opt(residual);
100
+ auto beta0_c = as_const_opt(beta0);
101
+ auto gamma1_c = as_const_opt(gamma1);
102
+ auto beta1_c = as_const_opt(beta1);
103
+
104
+ return dropout_add_ln_parallel_residual_fwd(
105
+ input, x1_c, residual_c, gamma0, beta0_c, gamma1_c, beta1_c,
106
+ static_cast<float>(dropout_p),
107
+ static_cast<float>(epsilon),
108
+ gen, residual_in_fp32, is_rms_norm);
109
+ }
110
+
111
+ // Parallel backward
112
+ static std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd_wrap(
113
+ const at::Tensor& dz0,
114
+ c10::optional<at::Tensor> dz1,
115
+ c10::optional<at::Tensor> dx,
116
+ const at::Tensor& x,
117
+ c10::optional<at::Tensor> dmask0,
118
+ c10::optional<at::Tensor> dmask1,
119
+ const at::Tensor& mu,
120
+ const at::Tensor& rsigma,
121
+ const at::Tensor& gamma0,
122
+ c10::optional<at::Tensor> gamma1,
123
+ double dropout_p,
124
+ bool has_x1,
125
+ bool has_residual,
126
+ bool is_rms_norm) {
127
+
128
+ auto dz1_c = as_const_opt(dz1);
129
+ auto dx_c = as_const_opt(dx);
130
+ auto dmask0_c = as_const_opt(dmask0);
131
+ auto dmask1_c = as_const_opt(dmask1);
132
+ auto gamma1_c = as_const_opt(gamma1);
133
+
134
+ return dropout_add_ln_parallel_residual_bwd(
135
+ dz0, dz1_c, dx_c, x, dmask0_c, dmask1_c, mu, rsigma, gamma0, gamma1_c,
136
+ static_cast<float>(dropout_p), has_x1, has_residual, is_rms_norm);
137
+ }
138
+
139
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
140
+ // Return lists to match std::vector<at::Tensor> from implementations
141
+ ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor? beta, Tensor? rowscale, Tensor? colscale, Tensor? x0_subset, Tensor? z_subset, float dropout_p, float epsilon, float rowscale_const, int z_numrows, Generator? gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor[]");
142
+ ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd_wrap);
143
+
144
+ ops.def("dropout_add_ln_bwd(Tensor dz, Tensor? dx, Tensor x, Tensor? x0, Tensor? dmask, Tensor mu, Tensor rsigma, Tensor gamma, Tensor? rowscale, Tensor? colscale, Tensor? x0_subset, Tensor? z_subset, float dropout_p, float rowscale_const, int x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor[]");
145
+ ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd_wrap);
146
+
147
+ ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor? x1, Tensor? residual, Tensor gamma0, Tensor? beta0, Tensor? gamma1, Tensor? beta1, float dropout_p, float epsilon, Generator? gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor[]");
148
+ ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd_wrap);
149
+
150
+ ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor? dz1, Tensor? dx, Tensor x, Tensor? dmask0, Tensor? dmask1, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor? gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor[]");
151
+ ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd_wrap);
152
  }
153
 
154
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h CHANGED
@@ -2,7 +2,69 @@
2
 
3
  #include <torch/torch.h>
4
 
5
- torch::Tensor dropout_add_ln_fwd(torch::Tensor &input, torch::Tensor &gamma, torch::Tensor &beta, torch::Tensor &rowscale, torch::Tensor &colscale, torch::Tensor &x0_subset, torch::Tensor &z_subset, float dropout_p, float epsilon, float rowscale_const, int64_t z_numrows, torch::Generator &gen, bool residual_in_fp32, bool is_rms_norm);
6
- torch::Tensor dropout_add_ln_bwd(torch::Tensor &dz, torch::Tensor &dx, torch::Tensor &x, torch::Tensor &mu, torch::Tensor &rsigma, torch::Tensor &gamma, torch::Tensor &rowscale, torch::Tensor &colscale, torch::Tensor &x0_subset, torch::Tensor &z_subset, float dropout_p, float rowscale_const, int64_t x0_numrows, bool has_residual, bool is_rms_norm);
7
- torch::Tensor dropout_add_ln_parallel_residual_fwd(torch::Tensor &input, torch::Tensor &gamma0, torch::Tensor &beta0, torch::Tensor &gamma1, torch::Tensor &beta1, float dropout_p, float epsilon, torch::Generator &gen, bool residual_in_fp32, bool is_rms_norm);
8
- torch::Tensor dropout_add_ln_parallel_residual_bwd(torch::Tensor &dz0, torch::Tensor &dz1, torch::Tensor &dx, torch::Tensor &x, torch::Tensor &mu, torch::Tensor &rsigma, torch::Tensor &gamma0, torch::Tensor &gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  #include <torch/torch.h>
4
 
5
+ // Declarations for implementations defined in layer_norm/ln_api.cpp
6
+ std::vector<at::Tensor> dropout_add_ln_fwd(
7
+ const at::Tensor &x0,
8
+ c10::optional<const at::Tensor> &residual,
9
+ const at::Tensor &gamma,
10
+ c10::optional<const at::Tensor> &beta,
11
+ c10::optional<const at::Tensor> &rowscale,
12
+ c10::optional<const at::Tensor> &colscale,
13
+ c10::optional<const at::Tensor> &x0_subset,
14
+ c10::optional<const at::Tensor> &z_subset,
15
+ const float dropout_p,
16
+ const float epsilon,
17
+ const float rowscale_const,
18
+ const int64_t z_numrows,
19
+ c10::optional<at::Generator> gen,
20
+ bool residual_in_fp32,
21
+ bool is_rms_norm);
22
+
23
+ std::vector<at::Tensor> dropout_add_ln_bwd(
24
+ const at::Tensor &dz,
25
+ c10::optional<const at::Tensor> &dx,
26
+ const at::Tensor &x,
27
+ c10::optional<const at::Tensor> &x0,
28
+ c10::optional<const at::Tensor> &dmask,
29
+ const at::Tensor &mu,
30
+ const at::Tensor &rsigma,
31
+ const at::Tensor &gamma,
32
+ c10::optional<const at::Tensor> &rowscale,
33
+ c10::optional<const at::Tensor> &colscale,
34
+ c10::optional<const at::Tensor> &x0_subset,
35
+ c10::optional<const at::Tensor> &z_subset,
36
+ const float dropout_p,
37
+ const float rowscale_const,
38
+ const int64_t x0_numrows,
39
+ const bool has_residual,
40
+ bool is_rms_norm);
41
+
42
+ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
43
+ const at::Tensor &x0,
44
+ c10::optional<const at::Tensor> &x1,
45
+ c10::optional<const at::Tensor> &residual,
46
+ const at::Tensor &gamma0,
47
+ c10::optional<const at::Tensor> &beta0,
48
+ c10::optional<const at::Tensor> &gamma1,
49
+ c10::optional<const at::Tensor> &beta1,
50
+ const float dropout_p,
51
+ const float epsilon,
52
+ c10::optional<at::Generator> gen,
53
+ bool residual_in_fp32,
54
+ bool is_rms_norm);
55
+
56
+ std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
57
+ const at::Tensor &dz0,
58
+ c10::optional<const at::Tensor> &dz1,
59
+ c10::optional<const at::Tensor> &dx,
60
+ const at::Tensor &x,
61
+ c10::optional<const at::Tensor> &dmask0,
62
+ c10::optional<const at::Tensor> &dmask1,
63
+ const at::Tensor &mu,
64
+ const at::Tensor &rsigma,
65
+ const at::Tensor &gamma0,
66
+ c10::optional<const at::Tensor> &gamma1,
67
+ const float dropout_p,
68
+ const bool has_x1,
69
+ const bool has_residual,
70
+ bool is_rms_norm);