Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +8 -0
- .venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp +87 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py +2167 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1043 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py +850 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py +128 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py +384 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py +916 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py +432 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +114 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +46 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +397 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +258 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +361 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -129,3 +129,11 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 129 |
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 130 |
.venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 131 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 130 |
.venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 131 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 132 |
+
.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text
|
| 133 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 134 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 135 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 136 |
+
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 137 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 138 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 139 |
+
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78df2f31f6db8142ec546a1e5a31cb066f7892d12d2f665b448f8069a08ef807
|
| 3 |
+
size 251616632
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1ba20726a513f57e01fc1fbf9c3744defdeda5d64e6e3a00d7d3911f4f598d2
|
| 3 |
+
size 164293
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:495965e46b513011b3387880294e810069bb3299277002dd35d6e15e1a3d6508
|
| 3 |
+
size 118734
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c863587cdf0f8eef657d2fa0f0ebf9ddddc19a24d5670869719203bc7d877e48
|
| 3 |
+
size 337621
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd3ce0e8ac0de613f90615aaf063bff822e142ca75c5993718647f82d9d0add5
|
| 3 |
+
size 109858
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a247c49f4e32c680bff1ed7aa611b6eae0c91d995c632fbc3fa35605649b638b
|
| 3 |
+
size 109445
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc
ADDED
|
Binary file (3.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc
ADDED
|
Binary file (49.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc
ADDED
|
Binary file (32.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc
ADDED
|
Binary file (7.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc
ADDED
|
Binary file (57.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc
ADDED
|
Binary file (6.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc
ADDED
|
Binary file (7.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc
ADDED
|
Binary file (44.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc
ADDED
|
Binary file (65.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc
ADDED
|
Binary file (9.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc
ADDED
|
Binary file (8.65 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// NOTE: Like interface.cpp, this file will be copied into AOTInductor
|
| 2 |
+
// generated output. This file is intended to keep implementation
|
| 3 |
+
// details separate from the implementation of the AOTI public
|
| 4 |
+
// interface. Note also that #includes should go into interface.cpp
|
| 5 |
+
// for simplicity of maintenance.
|
| 6 |
+
|
| 7 |
+
namespace torch {
|
| 8 |
+
namespace aot_inductor {
|
| 9 |
+
template <typename T>
|
| 10 |
+
void convert_output_to_handle(
|
| 11 |
+
const ArrayRefTensor<T>& output,
|
| 12 |
+
AtenTensorHandle& handle) {
|
| 13 |
+
handle = output.expensiveCopyToTensor();
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
template <typename... Ts, std::size_t... Is>
|
| 17 |
+
void convert_outputs_to_handles_helper(
|
| 18 |
+
const std::tuple<ArrayRefTensor<Ts>...>& outputs,
|
| 19 |
+
AtenTensorHandle* output_handles,
|
| 20 |
+
std::index_sequence<Is...>) {
|
| 21 |
+
(convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
|
| 22 |
+
}
|
| 23 |
+
template <typename... Ts>
|
| 24 |
+
void convert_outputs_to_handles(
|
| 25 |
+
const std::tuple<ArrayRefTensor<Ts>...>& outputs,
|
| 26 |
+
AtenTensorHandle* output_handles) {
|
| 27 |
+
convert_outputs_to_handles_helper(
|
| 28 |
+
outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <typename T>
|
| 32 |
+
void convert_handle_to_arrayref_tensor(
|
| 33 |
+
AtenTensorHandle handle,
|
| 34 |
+
ArrayRefTensor<T>& input) {
|
| 35 |
+
void* data_ptr;
|
| 36 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
|
| 37 |
+
int64_t dim;
|
| 38 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
|
| 39 |
+
int64_t numel;
|
| 40 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
|
| 41 |
+
int64_t* sizes;
|
| 42 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
|
| 43 |
+
int64_t* strides;
|
| 44 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
|
| 45 |
+
int32_t dtype;
|
| 46 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
|
| 47 |
+
int32_t device_type;
|
| 48 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
|
| 49 |
+
int32_t device_index;
|
| 50 |
+
AOTI_TORCH_ERROR_CODE_CHECK(
|
| 51 |
+
aoti_torch_get_device_index(handle, &device_index));
|
| 52 |
+
|
| 53 |
+
input = ArrayRefTensor<T>(
|
| 54 |
+
MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
|
| 55 |
+
MiniArrayRef<const int64_t>(sizes, dim),
|
| 56 |
+
MiniArrayRef<const int64_t>(strides, dim),
|
| 57 |
+
device_type,
|
| 58 |
+
device_index);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <typename... Ts, std::size_t... Is>
|
| 62 |
+
void convert_handles_to_inputs_helper(
|
| 63 |
+
AtenTensorHandle* input_handles,
|
| 64 |
+
std::tuple<ArrayRefTensor<Ts>...>& inputs,
|
| 65 |
+
std::index_sequence<Is...>) {
|
| 66 |
+
(convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
|
| 67 |
+
...);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
template <typename... Ts>
|
| 71 |
+
void convert_handles_to_inputs(
|
| 72 |
+
AtenTensorHandle* input_handles,
|
| 73 |
+
std::tuple<ArrayRefTensor<Ts>...>& inputs) {
|
| 74 |
+
convert_handles_to_inputs_helper(
|
| 75 |
+
input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
template <typename T>
|
| 79 |
+
void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
|
| 80 |
+
if (tensor.numel() != numel) {
|
| 81 |
+
std::stringstream err;
|
| 82 |
+
err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
|
| 83 |
+
throw std::runtime_error(err.str());
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
} // namespace aot_inductor
|
| 87 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py
ADDED
|
@@ -0,0 +1,2167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import operator
|
| 9 |
+
import re
|
| 10 |
+
from enum import auto, Enum
|
| 11 |
+
from itertools import chain
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Callable,
|
| 15 |
+
ClassVar,
|
| 16 |
+
Dict,
|
| 17 |
+
List,
|
| 18 |
+
NamedTuple,
|
| 19 |
+
Optional,
|
| 20 |
+
Tuple,
|
| 21 |
+
Union,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
import sympy
|
| 25 |
+
from sympy.printing.printer import Printer
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.fx
|
| 29 |
+
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
| 30 |
+
from torch.utils import _pytree as pytree
|
| 31 |
+
from torch.utils._ordered_set import OrderedSet
|
| 32 |
+
from torch.utils._sympy.numbers import int_oo
|
| 33 |
+
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
| 34 |
+
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
| 35 |
+
|
| 36 |
+
from .. import config, metrics
|
| 37 |
+
from ..utils import (
|
| 38 |
+
DeferredLineBase,
|
| 39 |
+
generate_assert,
|
| 40 |
+
IndentedBuffer,
|
| 41 |
+
sympy_dot,
|
| 42 |
+
sympy_subs,
|
| 43 |
+
unique,
|
| 44 |
+
)
|
| 45 |
+
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def data_type_logger(msg):
|
| 52 |
+
if schedule_log.isEnabledFor(logging.DEBUG):
|
| 53 |
+
schedule_log.debug("Data type propagation: %s", msg)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclasses.dataclass
|
| 57 |
+
class WorkspaceArg:
|
| 58 |
+
"""A temporary buffer used for a single kernel, then discarded.
|
| 59 |
+
|
| 60 |
+
Not registered as a traditional buffer since there are no users,
|
| 61 |
+
so it would be dead code eliminated.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
nbytes: sympy.Expr
|
| 65 |
+
zero_fill: bool
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclasses.dataclass
|
| 69 |
+
class TensorArg:
|
| 70 |
+
name: str
|
| 71 |
+
buffer: str
|
| 72 |
+
dtype: torch.dtype
|
| 73 |
+
offset: sympy.Expr = sympy.Integer(0) # c++ only
|
| 74 |
+
alias_of: Optional[str] = None # halide only
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclasses.dataclass
|
| 78 |
+
class SizeArg:
|
| 79 |
+
name: str
|
| 80 |
+
expr: sympy.Expr
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def alias_of(self):
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclasses.dataclass
|
| 88 |
+
class DeviceCodegen:
|
| 89 |
+
scheduling: Any
|
| 90 |
+
wrapper_codegen: type
|
| 91 |
+
cpp_wrapper_codegen: type = type(None)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
|
| 95 |
+
|
| 96 |
+
device_codegens: Dict[str, DeviceCodegen] = {}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DeviceOpOverrides:
    """Per-device hooks for emitting device-management code in the wrapper.

    Backends register a subclass via register_device_op_overrides(); every
    hook here must be overridden to return the backend-specific source text.
    """

    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError
# Registry mapping a device type string -> its DeviceOpOverrides instance.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
# The code generated by Inductor consists of two main parts: kernel code and
# wrapper code. Any new backend integrating with Inductor must customize both.
#
# Kernel code generation is determined by a Scheduling, so a new backend
# provides its own Scheduling class. Currently CppScheduling and
# TritonScheduling serve the C++/OpenMP and Triton backends respectively.
#
# For the wrapper, Inductor provides WrapperCodeGen to generate the Python
# wrapper code that bridges kernels. Out-of-tree backends can inherit from
# WrapperCodeGen and override member functions for backend-specific wrapper
# code.
#
# Other classes such as CppKernel and TritonKernel typically form part of the
# logic of either the Scheduling or the WrapperCodeGen, so those two
# interfaces give a backend full flexibility: implement from scratch, or
# extend and override as needed. register_backend_for_device is the runtime
# registration API that equips a new backend.
#
# Intel's Triton-based backend for Intel GPUs uses these interfaces and can
# serve as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register the Scheduling plus wrapper codegen classes for *device*."""
    device_codegens[device] = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )
class BackendFeature(Enum):
    """Optional capabilities a codegen backend may advertise.

    Queried via get_backend_features()/has_backend_feature(). Member order is
    kept stable since auto() assigns values positionally.
    """

    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()
def get_backend_features(device: Union[torch.device, str]):
    """Return the set of BackendFeatures supported by *device*'s backend."""
    init_backend_registration()
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        assert isinstance(device, str)
        device_type = device
        device = torch.device(device_type)
    scheduling_cls = get_scheduling_for_device(device_type)
    # A throwaway Scheduling instance is enough to query features.
    return scheduling_cls(None).get_backend_features(device)
def has_backend_feature(device, feature):
    """Return True if *device*'s backend supports *feature*.

    See also V.graph.has_feature.
    """
    assert isinstance(feature, BackendFeature)
    return feature in get_backend_features(device)
def get_scheduling_for_device(device: str):
    """Return the Scheduling registered for *device*, or None if unregistered."""
    entry = device_codegens.get(device)
    return entry.scheduling if entry is not None else None
def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for *device*.

    Picks the C++ wrapper codegen when cpp_wrapper is True, the Python one
    otherwise; returns None when no backend is registered for *device*.
    """
    entry = device_codegens.get(device)
    if entry is None:
        return None
    return entry.cpp_wrapper_codegen if cpp_wrapper else entry.wrapper_codegen
@functools.lru_cache(None)
def init_backend_registration():
    """Register the in-tree backends; cached so it runs effectively once."""
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cuda import CppWrapperCuda
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .triton import TritonScheduling
    from .wrapper import WrapperCodeGen

    if get_scheduling_for_device("cpu") is None:
        cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cpu",
            # Resolve config.cpu_backend lazily at scheduling-construction time.
            lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling delegates between Triton and CUDA C++ scheduling
        cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cuda",
            lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCuda,
        )

    if get_scheduling_for_device("xpu") is None:
        register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)

    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            # Best effort: the custom backend module may not expose these hooks.
            pass
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Append a contiguous index expression so loop reordering is suppressed."""
    from ..ir import FlexibleLayout

    # The extra contiguous index pins the current iteration order.
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register the DeviceOpOverrides implementation for *device*."""
    device_op_overrides_dict[device] = device_op_overrides
def get_device_op_overrides(device: str):
    """Look up the DeviceOpOverrides for *device*; None when unregistered."""
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        # Importing these modules registers the in-tree overrides as a
        # side effect of module import.
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict.get(device)
@functools.lru_cache(None)
def boolean_ops():
    """Names of ops whose result dtype is always torch.bool."""
    return (
        "isinf",
        "isnan",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )
DTYPE_TO_COMPUTATION_DTYPE = {
|
| 288 |
+
torch.bfloat16: torch.float,
|
| 289 |
+
torch.float16: torch.float,
|
| 290 |
+
**{
|
| 291 |
+
dtype: dtype
|
| 292 |
+
for dtype in [
|
| 293 |
+
torch.bool,
|
| 294 |
+
torch.float32,
|
| 295 |
+
torch.float64,
|
| 296 |
+
torch.int8,
|
| 297 |
+
torch.int16,
|
| 298 |
+
torch.int32,
|
| 299 |
+
torch.int64,
|
| 300 |
+
torch.uint8,
|
| 301 |
+
torch.uint16,
|
| 302 |
+
torch.uint32,
|
| 303 |
+
torch.uint64,
|
| 304 |
+
]
|
| 305 |
+
},
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def deduce_output_dtype_by_name(
    op_name: str,
    *args,
    **kwargs,
) -> Optional[torch.dtype]:
    """Deduce an op's output dtype from its name and call arguments.

    Returns None when the dtype cannot be determined from the name alone
    (callers then fall back to input-based promotion).
    """

    def _dtype_arg(position):
        # dtype may arrive as a keyword or at a known positional slot.
        return kwargs["dtype"] if "dtype" in kwargs else args[position]

    if op_name in boolean_ops():
        return torch.bool
    if op_name in ("to_dtype", "index_expr"):
        return _dtype_arg(-1)
    if op_name in ("rand", "randn"):
        return torch.float
    if op_name in ("get_index", "randint64", "load_seed"):
        return torch.int64
    if op_name == "reduction":
        return _dtype_arg(1)
    if op_name == "constant":
        return DTYPE_TO_COMPUTATION_DTYPE[_dtype_arg(-1)]  # type: ignore[index]
    if op_name in ("load", "store", "store_reduction"):
        buf_name = args[1]
        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
    if op_name == "to_dtype_bitcast":
        return _dtype_arg(-2)
    return None
class DataTypePropagation:
    """Propagate dtypes through a loop body's FX graphs.

    The deduced dtype of each node is stored on node.meta under
    OptimizationContext.key.
    """

    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for name, subblock in body.subblocks.items():
            self.graphs[name] = subblock.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote across all non-placeholder inputs; None if any is unknown."""
        input_nodes = [
            n
            for n in node.all_input_nodes
            if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if not input_nodes:
            return None

        for n in input_nodes:
            ctx = n.meta.get(OptimizationContext.key)
            if ctx is None or ctx.dtype is None:
                # An input has not been propagated yet; give up for now.
                return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the node's subgraph and use its output dtype."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # Output nodes are only inferable when they have exactly one arg.
            return None

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        output_dtype = deduce_output_dtype_by_name(
            node.target,
            *node.args,
            **node.kwargs,
        )
        if output_dtype is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate every node in *graph*; return the output node's dtype.

        For masked_subblock graphs the output dtype stands for the whole
        subgraph; for other graphs the return value may be None.
        """
        assert graph.nodes
        graph_dtype = None
        for node in graph.nodes:
            opt_ctx = node.meta.get(OptimizationContext.key)
            if opt_ctx is None:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)
# This printer contains rules that are supposed to be generic for both C/C++
# and Python
class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        """Wrap *string* in parentheses unless that would be redundant."""

        def all_in_parens(s):
            # True iff s is one balanced "(...)"" group spanning the whole string.
            if s[0] != "(" or len(s) < 2:
                return False
            depth = 1
            for i, ch in enumerate(s[1:]):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and i != len(s) - 2:
                    return False
            assert depth == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't add another layer of parens around an already-wrapped string
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        parts = [self.paren(self._print(a)) for a in expr.args]
        return f" {expr.rel_op} ".join(parts)

    def _print_Mul(self, expr):
        return "*".join(self.paren(self._print(a)) for a in expr.args)

    def _print_Add(self, expr):
        return " + ".join(self.paren(self._print(a)) for a in expr.args)

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(self.paren(self._print(a)) for a in expr.args)

    def _print_FloatTrueDiv(self, expr):
        numerator, denominator = expr.args
        return (
            f"{self.paren(self._print(numerator))}"
            f" / {self.paren(self._print(denominator))}"
        )

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        return self._print(expr.args[0])

    # sympy naming quirk: GreaterThan is >=, StrictlyGreaterThan is >.
    def _print_GreaterThan(self, expr):
        return " >= ".join(self.paren(self._print(a)) for a in expr.args)

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2)
    # without explicit intervention; we print it back as x * x and never
    # generate sympy.Pow with floats.
    #
    # NB: this is pow by natural; FloatPow handles float bases and
    # PowByNatural handles symbolic exponents, so exp here is guaranteed to
    # be a non-negative integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp == 0:
            return "1"
        return "*".join([self.paren(base)] * exp)

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which would otherwise emit raw ToFloat(...) into the IR. The
    # error message here names the printer subclass that must implement it.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers. You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as executable Python (math-module) source."""

    def _math_unary(self, fname, expr):
        # Shared shape for the single-argument math.* printers below.
        assert len(expr.args) == 1
        return f"math.{fname}({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(self.paren(self._print(a)) for a in expr.args)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        return self._math_unary("floor", expr)

    def _print_FloorToInt(self, expr):
        return self._math_unary("floor", expr)

    def _print_TruncToInt(self, expr):
        # int() would also work here; for floats they behave the same.
        return self._math_unary("trunc", expr)

    def _print_ceiling(self, expr):
        return self._math_unary("ceil", expr)

    def _print_CeilToInt(self, expr):
        return self._math_unary("ceil", expr)

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: Any promotion is already explicit in the sympy expression, so it
    # doesn't matter that Python's max/min don't promote.
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        return self._math_unary("cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr):
        return self._math_unary("cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr):
        return self._math_unary("acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr):
        return self._math_unary("sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr):
        return self._math_unary("sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr):
        return self._math_unary("asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr):
        return self._math_unary("tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr):
        return self._math_unary("tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr):
        return self._math_unary("atan", expr)

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
    """Base op handler that lowers composite ops onto a parent's primitives.

    Unknown attributes are forwarded to the wrapped parent handler; the
    static methods below define composite ops in terms of the `ops` virtual
    handler so every backend gets them for free.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Delegate anything we don't define to the wrapped handler.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        # log10(x) = log(x) / log(10)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        # log2(x) = log(x) / log(2)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        # 2**x = exp(x * ln 2)
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Python-style remainder: shift C-style mod into the sign of b when
        # the signs of r and b disagree.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the *target* backend's pointwise override implementations."""
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                # This backend has no implementation for this op.
                continue
            setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
class OverridesData:
    """Per-backend implementations of one pointwise override op."""

    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
# Each entry maps an op name to the per-backend code emitters; backends with
# no native implementation leave their field as None (see OverridesData).
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    # gammainc/gammaincc share the igamma/igammac C++ kernels
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    # NOTE: calc_polygamma takes (order, input) — arguments are swapped here
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Never called at runtime: exists so mypy verifies that OpOverrides
    structurally satisfies the OpsHandler[str] protocol."""
    return h
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
class DeferredLine(DeferredLineBase):
    """A buffered output line that is dropped if its buffer is later removed.

    The line is only emitted when ``name`` is absent from every
    removed/inplaced-to-remove set on both the graph and the kernel.
    """

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        # nesting DeferredLines would double-defer; forbid it
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in removed for removed in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        return DeferredLine(self.name, line)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer that emits C/C++ braces when entering/leaving scopes."""

    def indent(self, offset=1):
        # Returns a context manager that opens ``offset`` brace levels on
        # entry and closes them on exit.  A negative ``offset`` does the
        # reverse: it closes |offset| existing levels on entry and reopens
        # them on exit.  Exactly one of each pair of loops below runs,
        # depending on the sign of ``offset``.
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
class InplacedBuffer(NamedTuple):
    # kernel-internal argument name shared by all aliases (e.g. "in_out_ptr0")
    inner_name: str
    # graph-level buffer names that alias this in/out argument
    other_names: List[str]
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
class KernelArgs:
    """Tracks the buffers and size variables a kernel touches and assigns
    them stable argument names (``in_ptrN`` / ``out_ptrN`` / ``in_out_ptrN``
    / ``ksN``) used in the generated kernel signature."""

    @staticmethod
    def _lookup(prefix, odict, name):
        # Assign the next sequential "{prefix}{i}" name on first lookup;
        # subsequent lookups of the same name are stable.
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        # outer buffer name -> inner argument name
        self.input_buffers = {}
        self.output_buffers = {}
        # outer buffer name -> InplacedBuffer (shared among aliases)
        self.inplace_buffers = {}
        # size expression -> inner argument name
        self.sizevars = sizevars or {}
        # single accumulated WorkspaceArg, or None if no workspace requested
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        # Removed inplace buffers are tombstoned with a "REMOVED..." string.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the kernel-internal argument name for reading ``name``."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        # a buffer already registered as output/inplace keeps that name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Return the kernel-internal argument name for writing ``name``."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Alias ``output_name`` onto ``input_name``'s storage, creating or
        extending a shared ``in_out_ptrN`` argument."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            # extend the existing alias group
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` of scratch space; returns ("ws_ptr", offset).

        Requests accumulate into a single workspace allocation; zero_fill
        is sticky (once any request wants zeroing, the whole buffer is
        zero-filled).
        """
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Register size-var ``value`` under ``name``, uniquifying the name
        with a numeric suffix if it is already taken."""
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the kernel-internal name for size expression ``name``;
        "seed" is special-cased to keep its own name."""
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        # outer names, in the same grouping order used by *_argdefs
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        # hook for subclasses (see CppWrapperKernelArgs)
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Build C++ signature pieces: (arg_defs, call_args, arg_types),
        ordered inplace -> inputs -> outputs -> sizevars."""
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            # the last alias is the buffer that holds the final value
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build Python-side signature pieces:
        (arg_defs, call_args, precompile_args, arg_types), same ordering
        as cpp_argdefs plus an optional trailing workspace argument."""
        arg_defs: List[str] = []
        call_args: List[str] = []
        arg_types: List[torch.dtype] = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))  # type: ignore[arg-type]
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        """Yield (old inner name, in_out inner name) pairs for buffers that
        were registered separately before becoming inplace aliases."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        # A buffer is removed if it is absent-or-tombstoned in BOTH the
        # output and inplace maps.
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = OrderedSet()  # type: ignore[var-annotated]
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
class CSEVariable:
    """A named, backend-annotatable handle for a CSE'd expression.

    It is just a name plus value-range bounds, but backends can subclass it
    (by overloading ``Kernel.create_cse_var``) and use ``update_on_args`` as
    a hook for backend-dependent annotations.  See TritonCSEVariable in
    triton.py for an example.
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        # number of times this expression has been used so far
        self.use_count = 1

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        # equal only to same-class variables with the same name
        if type(other) != type(self):
            return False
        return other.name == self.name

    def update_on_args(self, name, args, kwargs):
        # hook for subclasses; no-op by default
        pass

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r})"
|
| 1451 |
+
|
| 1452 |
+
|
| 1453 |
+
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant that renders call arguments for the C++ wrapper."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if not config.abi_compatible:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        return f"{size}"
|
| 1466 |
+
|
| 1467 |
+
|
| 1468 |
+
class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        # text prepended/appended to each generated assignment line
        self.prefix = prefix
        self.suffix = suffix
        # expression text -> CSEVariable
        self.cache = {}
        self.name_prefix = name_prefix
        # buffer name -> variable last stored to it
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        # shared counter so clones keep producing unique names
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]
        # variable name -> CSEVariable
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: OrderedSet[str]):
        """Drop all cached expressions/stores except those in ``keep_vars``."""
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Copy sharing the name counter and store cache with the original."""
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return a CSEVariable for ``expr``, emitting code into ``buffer``
        only the first time a given expression text is seen.

        ``write=False`` caches without emitting; ``assignment=False`` emits
        the expression as a bare statement (no ``var = ...``).
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            # first sighting: allocate a variable and (optionally) emit code
            var = self.newvar(bounds)
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            # cache hit: reuse the variable, just merge bounds and count the use
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        """Allocate a fresh uniquely-named variable via the active kernel."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
|
| 1562 |
+
|
| 1563 |
+
|
| 1564 |
+
class CodeGen:
    """Base class pairing codegen state with an ExitStack-backed context.

    Subclasses push nested context managers onto ``self.exit_stack`` while
    active; entering/exiting a CodeGen enters/exits that stack.
    """

    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Propagate the ExitStack result: a truthy return tells Python that a
        # registered context manager suppressed the in-flight exception.
        # The previous version discarded it, re-raising exceptions that
        # callbacks on the stack had already suppressed.
        return self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
| 1575 |
+
|
| 1576 |
+
|
| 1577 |
+
class ScopedDict:
    """A write-through overlay on top of an existing mapping.

    Reads fall back to the wrapped ``original_dict``; writes land only in
    ``new_items``, leaving the original untouched.  Used to scope CSE caches
    inside ``Kernel.swap_buffers``.
    """

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        try:
            return self.new_items[key]
        except KeyError:
            return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        if key in self.new_items:
            return True
        return key in self.original_dict

    def get(self, key, default=None):
        try:
            return self.new_items[key]
        except KeyError:
            return self.original_dict.get(key, default)
|
| 1597 |
+
|
| 1598 |
+
|
| 1599 |
+
class Kernel(CodeGen):
|
| 1600 |
+
newvar_prefix = ""
|
| 1601 |
+
suffix = ""
|
| 1602 |
+
overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
|
| 1603 |
+
# TODO: these look dead, but with all the getattr it's hard to tell...
|
| 1604 |
+
load_format: None = None
|
| 1605 |
+
store_format: None = None
|
| 1606 |
+
|
| 1607 |
+
def __init__(self, args=None, increase_kernel_count=True):
    """Set up per-kernel codegen state.

    ``args`` is an optional pre-populated KernelArgs (a fresh one is created
    otherwise); ``increase_kernel_count`` controls the global metrics counter.
    """
    super().__init__()
    if increase_kernel_count:
        metrics.generated_kernel_count += 1
    self.args = args or KernelArgs()
    # code is staged into three sections: loads, compute, stores
    self.loads = IndentedBuffer()
    self.compute = IndentedBuffer()
    self.stores = IndentedBuffer()

    self.num_load = 0
    self.num_reduction = 0

    self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
    self.must_keep_buffers = OrderedSet()  # type: ignore[var-annotated]
    self.store_buffer_names = OrderedSet()  # type: ignore[var-annotated]
    self._load_mask = None
    self._load_other = None
    # OrderedSet in set_current_node
    self.current_node = None
    self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None

    self.removed_buffers = OrderedSet()  # type: ignore[var-annotated]
    self.inplaced_to_remove = OrderedSet()  # type: ignore[var-annotated]

    # key: the buffer to write
    # value: the buffer to read and whose memory can be reused for
    # the buffer specified by key
    self.inplace_update_buffers = {}
    # Set minimum number of elements processed per thread.
    self.min_elem_per_thread = 1
    self.kernel_name = None
|
| 1638 |
+
|
| 1639 |
+
@contextlib.contextmanager
def set_current_node(self, node):
    """Temporarily make ``node`` the active scheduler node, computing its
    value-range bounds for use by codegen; restores the prior node on exit."""
    prior = self.current_node
    self.current_node = node
    self.node_to_bounds = node._body.bounds().get_bounds()
    try:
        yield
    finally:
        self.current_node = prior
|
| 1648 |
+
|
| 1649 |
+
@contextlib.contextmanager
|
| 1650 |
+
def swap_buffers(self, lb, cb=None, sb=None):
|
| 1651 |
+
def scope_cse(cse):
|
| 1652 |
+
new_cse = cse.clone()
|
| 1653 |
+
new_cse.cache = ScopedDict(cse.cache)
|
| 1654 |
+
new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
|
| 1655 |
+
new_cse.store_cache = ScopedDict(cse.store_cache)
|
| 1656 |
+
return new_cse
|
| 1657 |
+
|
| 1658 |
+
if cb is None:
|
| 1659 |
+
cb = lb
|
| 1660 |
+
loads = self.loads
|
| 1661 |
+
compute = self.compute
|
| 1662 |
+
stores = self.stores
|
| 1663 |
+
cse = self.cse
|
| 1664 |
+
self.loads = lb
|
| 1665 |
+
self.compute = cb
|
| 1666 |
+
self.stores = sb
|
| 1667 |
+
self.cse = scope_cse(cse)
|
| 1668 |
+
try:
|
| 1669 |
+
yield
|
| 1670 |
+
finally:
|
| 1671 |
+
self.loads = loads
|
| 1672 |
+
self.compute = compute
|
| 1673 |
+
self.stores = stores
|
| 1674 |
+
self.cse = cse
|
| 1675 |
+
|
| 1676 |
+
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
| 1677 |
+
raise NotImplementedError
|
| 1678 |
+
|
| 1679 |
+
def indirect_load(self, name: str, index: sympy.Expr):
|
| 1680 |
+
"""A load the depends on an index we have read"""
|
| 1681 |
+
prior = self.loads
|
| 1682 |
+
try:
|
| 1683 |
+
# put the load in the compute section as it might have deps
|
| 1684 |
+
self.loads = self.compute
|
| 1685 |
+
return self.load(name, index)
|
| 1686 |
+
finally:
|
| 1687 |
+
self.loads = prior
|
| 1688 |
+
|
| 1689 |
+
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
|
| 1690 |
+
raise NotImplementedError
|
| 1691 |
+
|
| 1692 |
+
def store(
|
| 1693 |
+
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
|
| 1694 |
+
) -> None:
|
| 1695 |
+
raise NotImplementedError
|
| 1696 |
+
|
| 1697 |
+
def reduction(
|
| 1698 |
+
self,
|
| 1699 |
+
dtype: torch.dtype,
|
| 1700 |
+
src_dtype: torch.dtype,
|
| 1701 |
+
reduction_type: ReductionType,
|
| 1702 |
+
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
|
| 1703 |
+
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
|
| 1704 |
+
raise NotImplementedError
|
| 1705 |
+
|
| 1706 |
+
def scan(
|
| 1707 |
+
self,
|
| 1708 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1709 |
+
combine_fn: Callable[
|
| 1710 |
+
[Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
|
| 1711 |
+
],
|
| 1712 |
+
values: Tuple[CSEVariable, ...],
|
| 1713 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1714 |
+
raise NotImplementedError
|
| 1715 |
+
|
| 1716 |
+
def sort(
|
| 1717 |
+
self,
|
| 1718 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1719 |
+
values: Tuple[CSEVariable, ...],
|
| 1720 |
+
stable: bool,
|
| 1721 |
+
descending: bool,
|
| 1722 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1723 |
+
raise NotImplementedError
|
| 1724 |
+
|
| 1725 |
+
def var_ranges(self):
|
| 1726 |
+
raise NotImplementedError
|
| 1727 |
+
|
| 1728 |
+
def bucketize(
|
| 1729 |
+
self,
|
| 1730 |
+
values: CSEVariable,
|
| 1731 |
+
offsets_name: str,
|
| 1732 |
+
offsets_size: sympy.Expr,
|
| 1733 |
+
indexing_dtype: torch.dtype,
|
| 1734 |
+
right: bool,
|
| 1735 |
+
) -> CSEVariable:
|
| 1736 |
+
"""
|
| 1737 |
+
See [Note: Inductor bucketize op]
|
| 1738 |
+
"""
|
| 1739 |
+
raise NotImplementedError
|
| 1740 |
+
|
| 1741 |
+
@property
|
| 1742 |
+
def assert_function(self) -> str:
|
| 1743 |
+
raise NotImplementedError
|
| 1744 |
+
|
| 1745 |
+
def indirect_assert(
|
| 1746 |
+
self,
|
| 1747 |
+
var: Union[CSEVariable, str],
|
| 1748 |
+
lower: Optional[str],
|
| 1749 |
+
upper: Optional[str],
|
| 1750 |
+
mask: Optional[Union[CSEVariable, str]] = None,
|
| 1751 |
+
) -> str:
|
| 1752 |
+
if isinstance(var, CSEVariable):
|
| 1753 |
+
var = str(var)
|
| 1754 |
+
assert isinstance(var, str)
|
| 1755 |
+
assert lower is None or isinstance(lower, str)
|
| 1756 |
+
assert upper is None or isinstance(upper, str)
|
| 1757 |
+
if lower and upper:
|
| 1758 |
+
# The conditions need to be in parens because of Python's operator precedence.
|
| 1759 |
+
# It'd be less error-prone to use and/or/not, which is suported by triton
|
| 1760 |
+
cond = f"({lower} <= {var}) & ({var} < {upper})"
|
| 1761 |
+
cond_print = f"{lower} <= {var} < {upper}"
|
| 1762 |
+
elif lower:
|
| 1763 |
+
cond = f"{lower} <= {var}"
|
| 1764 |
+
cond_print = cond
|
| 1765 |
+
else:
|
| 1766 |
+
assert upper
|
| 1767 |
+
cond = f"{var} < {upper}"
|
| 1768 |
+
cond_print = cond
|
| 1769 |
+
|
| 1770 |
+
if mask:
|
| 1771 |
+
cond = f"({cond}) | ~({mask})"
|
| 1772 |
+
|
| 1773 |
+
return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
|
| 1774 |
+
|
| 1775 |
+
def check_bounds(
|
| 1776 |
+
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
| 1777 |
+
):
|
| 1778 |
+
raise NotImplementedError
|
| 1779 |
+
|
| 1780 |
+
def index_to_str(self, index: sympy.Expr) -> str:
|
| 1781 |
+
raise NotImplementedError
|
| 1782 |
+
|
| 1783 |
+
def __enter__(self):
|
| 1784 |
+
# TODO: hoist this to top level
|
| 1785 |
+
class CSEProxy:
|
| 1786 |
+
self.name = "CSEProxy"
|
| 1787 |
+
vr_analysis = ValueRangeAnalysis()
|
| 1788 |
+
|
| 1789 |
+
@staticmethod
|
| 1790 |
+
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
|
| 1791 |
+
def inner(*args, **kwargs):
|
| 1792 |
+
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
|
| 1793 |
+
|
| 1794 |
+
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
|
| 1795 |
+
|
| 1796 |
+
def do_cse(v):
|
| 1797 |
+
csevar = V.kernel.cse.generate(
|
| 1798 |
+
V.kernel.compute, v, bounds=bounds
|
| 1799 |
+
)
|
| 1800 |
+
csevar.update_on_args(name, args, kwargs)
|
| 1801 |
+
return csevar
|
| 1802 |
+
|
| 1803 |
+
return pytree.tree_map(do_cse, value)
|
| 1804 |
+
|
| 1805 |
+
return inner
|
| 1806 |
+
|
| 1807 |
+
@staticmethod
|
| 1808 |
+
def _bound_variable(name, *args, **kwargs):
|
| 1809 |
+
"""
|
| 1810 |
+
If the variable comes from an FX node, we forward the bound we have already computed
|
| 1811 |
+
Else, if the variable when codegen'ing another op, we try to compute its bounds
|
| 1812 |
+
"""
|
| 1813 |
+
from ..select_algorithm import TritonTemplateKernel
|
| 1814 |
+
|
| 1815 |
+
if isinstance(V.kernel, TritonTemplateKernel):
|
| 1816 |
+
return ValueRanges.unknown()
|
| 1817 |
+
|
| 1818 |
+
fx_node = V.interpreter.current_node
|
| 1819 |
+
if fx_node.target == name and self.node_to_bounds is not None:
|
| 1820 |
+
assert isinstance(self.node_to_bounds, dict)
|
| 1821 |
+
return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
|
| 1822 |
+
elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
|
| 1823 |
+
# These create lots of inner strings. We would need to compute the bounds at the ops
|
| 1824 |
+
# We will also likely not get much from computing VRs on these nodes
|
| 1825 |
+
if any(
|
| 1826 |
+
s in fx_node.target
|
| 1827 |
+
for s in ("set_indirect", "reduction", "scan")
|
| 1828 |
+
):
|
| 1829 |
+
return ValueRanges.unknown()
|
| 1830 |
+
|
| 1831 |
+
# We assume that the inputs come from `ops.` and are not strings. If you want to generate
|
| 1832 |
+
# intermediary strings, wrap them in CSE variables with properly initialised bounds.
|
| 1833 |
+
|
| 1834 |
+
# If there is no FX bound but we know how to compute one we do so
|
| 1835 |
+
assert not kwargs
|
| 1836 |
+
|
| 1837 |
+
def arg_to_bound(x):
|
| 1838 |
+
if isinstance(x, CSEVariable):
|
| 1839 |
+
return x.bounds
|
| 1840 |
+
elif isinstance(x, sympy.Expr):
|
| 1841 |
+
return bound_sympy(x)
|
| 1842 |
+
else:
|
| 1843 |
+
return x
|
| 1844 |
+
|
| 1845 |
+
arg_bounds = list(map(arg_to_bound, args))
|
| 1846 |
+
return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
|
| 1847 |
+
else:
|
| 1848 |
+
return ValueRanges.unknown()
|
| 1849 |
+
|
| 1850 |
+
@staticmethod
|
| 1851 |
+
def indirect_indexing(
|
| 1852 |
+
var: CSEVariable,
|
| 1853 |
+
size: Union[sympy.Expr, int],
|
| 1854 |
+
check: bool = True,
|
| 1855 |
+
wrap_neg=True,
|
| 1856 |
+
):
|
| 1857 |
+
if isinstance(size, int):
|
| 1858 |
+
size = sympy.Integer(size)
|
| 1859 |
+
assert isinstance(size, sympy.Expr), size
|
| 1860 |
+
# Skip CSE since this doesn't return an expression
|
| 1861 |
+
|
| 1862 |
+
if var.bounds.lower < 0: # type: ignore[operator]
|
| 1863 |
+
if wrap_neg:
|
| 1864 |
+
stm = ops.add(var, ops.index_expr(size, torch.long))
|
| 1865 |
+
# Mixed negative and non-negative
|
| 1866 |
+
if var.bounds.upper >= 0: # type: ignore[operator]
|
| 1867 |
+
lt = ops.lt(var, 0)
|
| 1868 |
+
stm = ops.where(lt, stm, var)
|
| 1869 |
+
else:
|
| 1870 |
+
stm = var
|
| 1871 |
+
|
| 1872 |
+
# Propagate bounds as we know how to compute them properly
|
| 1873 |
+
new_bounds = ValueRanges.unknown()
|
| 1874 |
+
if var.bounds != ValueRanges.unknown() and isinstance(
|
| 1875 |
+
size, sympy.Number
|
| 1876 |
+
):
|
| 1877 |
+
# Take the negative part of the bound and add size to it
|
| 1878 |
+
# Then take union of that and the positive part
|
| 1879 |
+
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
|
| 1880 |
+
neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
|
| 1881 |
+
new_bounds = ValueRanges(
|
| 1882 |
+
neg_bounds.lower + size, neg_bounds.upper + size
|
| 1883 |
+
)
|
| 1884 |
+
# We don't have a good way of representing the empty range
|
| 1885 |
+
if var.bounds.upper >= 0: # type: ignore[operator]
|
| 1886 |
+
pos = var.bounds & ValueRanges(0, int_oo)
|
| 1887 |
+
new_bounds = new_bounds | pos
|
| 1888 |
+
|
| 1889 |
+
var = self.cse.generate(self.compute, stm, bounds=new_bounds)
|
| 1890 |
+
|
| 1891 |
+
sympy_var = parent_handler.indirect_indexing(var, size, check)
|
| 1892 |
+
if generate_assert(check):
|
| 1893 |
+
assert_lower = not (var.bounds.lower >= 0)
|
| 1894 |
+
# value ranges cannot x < s when x and s are symbols
|
| 1895 |
+
assert_upper = not isinstance(size, sympy.Number) or not (
|
| 1896 |
+
var.bounds.upper < size
|
| 1897 |
+
)
|
| 1898 |
+
self.check_bounds(sympy_var, size, assert_lower, assert_upper)
|
| 1899 |
+
return sympy_var
|
| 1900 |
+
|
| 1901 |
+
@staticmethod
|
| 1902 |
+
def check_bounds(
|
| 1903 |
+
expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
| 1904 |
+
):
|
| 1905 |
+
return self.check_bounds(expr, size, lower, upper)
|
| 1906 |
+
|
| 1907 |
+
@staticmethod
|
| 1908 |
+
def load(name: str, index: sympy.Expr) -> CSEVariable:
|
| 1909 |
+
if name in self.cse.invalidated_stores:
|
| 1910 |
+
# A load from an invalidated store requires us to
|
| 1911 |
+
# keep the actual buffer around
|
| 1912 |
+
V.kernel.must_keep_buffers.add(name)
|
| 1913 |
+
if free_symbol_is_type(index, SymT.TMP):
|
| 1914 |
+
return self.indirect_load(name, index)
|
| 1915 |
+
store_cache = self.cse.store_cache
|
| 1916 |
+
if name in store_cache:
|
| 1917 |
+
return store_cache[name]
|
| 1918 |
+
out = self.load(name, index)
|
| 1919 |
+
# count load that is not in the store_cache, and also not in the
|
| 1920 |
+
# cse cache.
|
| 1921 |
+
if out.use_count == 1:
|
| 1922 |
+
self.num_load += 1
|
| 1923 |
+
return out
|
| 1924 |
+
|
| 1925 |
+
@staticmethod
|
| 1926 |
+
def _update_store_cache(name: str, value: CSEVariable):
|
| 1927 |
+
self.cse.store_cache[name] = value
|
| 1928 |
+
if self.current_node and name in V.graph.name_to_buffer:
|
| 1929 |
+
buf = self.current_node.get_output(name)
|
| 1930 |
+
for other_name in buf.get_mutations():
|
| 1931 |
+
self.cse.store_cache[other_name] = value
|
| 1932 |
+
|
| 1933 |
+
@staticmethod
|
| 1934 |
+
def store(
|
| 1935 |
+
name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
|
| 1936 |
+
) -> None:
|
| 1937 |
+
self.store_buffer_names.add(name)
|
| 1938 |
+
if mode is None:
|
| 1939 |
+
CSEProxy._update_store_cache(name, value)
|
| 1940 |
+
if name not in V.graph.removed_buffers:
|
| 1941 |
+
return self.store(name, index, value, mode=mode)
|
| 1942 |
+
else:
|
| 1943 |
+
return None # type: ignore[return-value]
|
| 1944 |
+
|
| 1945 |
+
@staticmethod
|
| 1946 |
+
def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
|
| 1947 |
+
self.store_buffer_names.add(name)
|
| 1948 |
+
CSEProxy._update_store_cache(name, value)
|
| 1949 |
+
|
| 1950 |
+
if name not in V.graph.removed_buffers:
|
| 1951 |
+
return self.store_reduction(name, index, value)
|
| 1952 |
+
|
| 1953 |
+
@staticmethod
|
| 1954 |
+
def reduction(
|
| 1955 |
+
dtype: torch.dtype,
|
| 1956 |
+
src_dtype: torch.dtype,
|
| 1957 |
+
reduction_type: ReductionType,
|
| 1958 |
+
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
|
| 1959 |
+
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
|
| 1960 |
+
self.num_reduction += 1
|
| 1961 |
+
return self.reduction(dtype, src_dtype, reduction_type, value)
|
| 1962 |
+
|
| 1963 |
+
@staticmethod
|
| 1964 |
+
def scan(
|
| 1965 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1966 |
+
combine_fn: Callable[
|
| 1967 |
+
[Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
|
| 1968 |
+
Tuple[CSEVariable, ...],
|
| 1969 |
+
],
|
| 1970 |
+
values: Tuple[CSEVariable, ...],
|
| 1971 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1972 |
+
return self.scan(dtypes, combine_fn, values)
|
| 1973 |
+
|
| 1974 |
+
@staticmethod
|
| 1975 |
+
def sort(
|
| 1976 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1977 |
+
values: Tuple[CSEVariable, ...],
|
| 1978 |
+
stable: bool,
|
| 1979 |
+
descending: bool,
|
| 1980 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1981 |
+
return self.sort(dtypes, values, stable, descending)
|
| 1982 |
+
|
| 1983 |
+
@staticmethod
|
| 1984 |
+
def bucketize(
|
| 1985 |
+
values: CSEVariable,
|
| 1986 |
+
offsets_name: str,
|
| 1987 |
+
offsets_size: sympy.Expr,
|
| 1988 |
+
indexing_dtype: torch.dtype,
|
| 1989 |
+
right: bool,
|
| 1990 |
+
) -> CSEVariable:
|
| 1991 |
+
"""
|
| 1992 |
+
[Note: Inductor bucketize op]
|
| 1993 |
+
|
| 1994 |
+
Given values (tensor) and offsets_name (reference to the name of a 1D
|
| 1995 |
+
tensor), calculate the bucket that each value belongs to.
|
| 1996 |
+
|
| 1997 |
+
e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
|
| 1998 |
+
return = [ 0, 1, 1, 1, 1, 3, 3, 4].
|
| 1999 |
+
|
| 2000 |
+
When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
|
| 2001 |
+
When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
|
| 2002 |
+
|
| 2003 |
+
Offsets must be non-decreasing or the result is undefined.
|
| 2004 |
+
"""
|
| 2005 |
+
return self.bucketize(
|
| 2006 |
+
values, offsets_name, offsets_size, indexing_dtype, right
|
| 2007 |
+
)
|
| 2008 |
+
|
| 2009 |
+
# Use mypy to check protocol implemented correctly
|
| 2010 |
+
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
|
| 2011 |
+
return h
|
| 2012 |
+
|
| 2013 |
+
super().__enter__()
|
| 2014 |
+
assert self.overrides
|
| 2015 |
+
parent_handler = self.overrides(V.get_ops_handler())
|
| 2016 |
+
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
|
| 2017 |
+
self.exit_stack.enter_context(V.set_kernel_handler(self))
|
| 2018 |
+
return self
|
| 2019 |
+
|
| 2020 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 2021 |
+
"""
|
| 2022 |
+
Note that V.graph.scheduler can be None when codegening triton template
|
| 2023 |
+
kernels.
|
| 2024 |
+
"""
|
| 2025 |
+
if V.graph.scheduler:
|
| 2026 |
+
V.graph.scheduler.remove_kernel_local_buffers()
|
| 2027 |
+
super().__exit__(exc_type, exc_val, exc_tb)
|
| 2028 |
+
|
| 2029 |
+
def rename_indexing(self, index) -> sympy.Expr:
|
| 2030 |
+
# adds the necessary kernel args for index expressions
|
| 2031 |
+
# and renames variables in index expressions to kernel arg names
|
| 2032 |
+
if isinstance(index, (list, tuple)):
|
| 2033 |
+
return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
|
| 2034 |
+
index = V.graph.sizevars.simplify(index)
|
| 2035 |
+
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
|
| 2036 |
+
replacements = {
|
| 2037 |
+
x: self.args.size(x)
|
| 2038 |
+
for x in sorted_symbols
|
| 2039 |
+
if symbol_is_type(
|
| 2040 |
+
x,
|
| 2041 |
+
(
|
| 2042 |
+
SymT.UNBACKED_INT,
|
| 2043 |
+
SymT.SIZE,
|
| 2044 |
+
SymT.PRECOMPUTED_SIZE,
|
| 2045 |
+
),
|
| 2046 |
+
)
|
| 2047 |
+
}
|
| 2048 |
+
return sympy_subs(index, replacements)
|
| 2049 |
+
|
| 2050 |
+
def create_cse_var(self, *args, **kwargs):
|
| 2051 |
+
return CSEVariable(*args, **kwargs)
|
| 2052 |
+
|
| 2053 |
+
|
| 2054 |
+
@dataclasses.dataclass
|
| 2055 |
+
class OptimizationContext:
|
| 2056 |
+
key: ClassVar[str] = "opt_ctx"
|
| 2057 |
+
|
| 2058 |
+
dtype: Optional[torch.dtype] = None
|
| 2059 |
+
ops_name: str = ""
|
| 2060 |
+
|
| 2061 |
+
|
| 2062 |
+
@functools.lru_cache(None)
|
| 2063 |
+
def jinja2_env():
|
| 2064 |
+
try:
|
| 2065 |
+
import jinja2
|
| 2066 |
+
|
| 2067 |
+
return jinja2.Environment(
|
| 2068 |
+
undefined=jinja2.StrictUndefined,
|
| 2069 |
+
)
|
| 2070 |
+
except ImportError:
|
| 2071 |
+
return None
|
| 2072 |
+
|
| 2073 |
+
|
| 2074 |
+
class KernelTemplate:
|
| 2075 |
+
"""
|
| 2076 |
+
Base class for defining kernel templates.
|
| 2077 |
+
|
| 2078 |
+
Children classes: TritonTemplate, CUDATemplate
|
| 2079 |
+
"""
|
| 2080 |
+
|
| 2081 |
+
@staticmethod
|
| 2082 |
+
def indent_except_first(source: str, num_indents: int, indents_spacing=4):
|
| 2083 |
+
lines = source.splitlines(True)
|
| 2084 |
+
if len(lines) > 1:
|
| 2085 |
+
lines[1:] = [
|
| 2086 |
+
(" " * indents_spacing * num_indents) + line for line in lines[1:]
|
| 2087 |
+
]
|
| 2088 |
+
return "".join(lines)
|
| 2089 |
+
|
| 2090 |
+
@staticmethod
|
| 2091 |
+
def _template_from_string(source):
|
| 2092 |
+
env = jinja2_env()
|
| 2093 |
+
if env is not None:
|
| 2094 |
+
env.filters["indent_except_first"] = KernelTemplate.indent_except_first
|
| 2095 |
+
from jinja2 import TemplateSyntaxError
|
| 2096 |
+
|
| 2097 |
+
class DetailedTemplateSyntaxError(TemplateSyntaxError):
|
| 2098 |
+
def __init__(self, original_error):
|
| 2099 |
+
super().__init__(
|
| 2100 |
+
original_error.message,
|
| 2101 |
+
original_error.lineno,
|
| 2102 |
+
original_error.name,
|
| 2103 |
+
original_error.filename,
|
| 2104 |
+
)
|
| 2105 |
+
self.original_error = original_error
|
| 2106 |
+
|
| 2107 |
+
def __str__(self):
|
| 2108 |
+
error_info = f"Error in template at line {self.lineno}\n"
|
| 2109 |
+
error_info += f"Error message: {self.message}\n"
|
| 2110 |
+
if hasattr(self.original_error, "source"):
|
| 2111 |
+
lines = self.original_error.source.split("\n")
|
| 2112 |
+
error_info += "Context:\n"
|
| 2113 |
+
start = max(0, self.lineno - 2)
|
| 2114 |
+
end = min(len(lines), self.lineno + 2)
|
| 2115 |
+
for i in range(start, end):
|
| 2116 |
+
if i == self.lineno - 1:
|
| 2117 |
+
error_info += f"{i+1}: --> {lines[i]}\n"
|
| 2118 |
+
if hasattr(self.original_error, "column"):
|
| 2119 |
+
error_info += (
|
| 2120 |
+
" "
|
| 2121 |
+
+ " " * (self.original_error.column - 1)
|
| 2122 |
+
+ "^\n"
|
| 2123 |
+
)
|
| 2124 |
+
else:
|
| 2125 |
+
error_info += f"{i+1}: {lines[i]}\n"
|
| 2126 |
+
return error_info
|
| 2127 |
+
|
| 2128 |
+
try:
|
| 2129 |
+
return env.from_string(source)
|
| 2130 |
+
except TemplateSyntaxError as e:
|
| 2131 |
+
raise DetailedTemplateSyntaxError(e) from e
|
| 2132 |
+
|
| 2133 |
+
return None
|
| 2134 |
+
|
| 2135 |
+
@staticmethod
|
| 2136 |
+
def _fake_get_dtype(fake_out):
|
| 2137 |
+
_get_dtype_real = V.graph.get_dtype
|
| 2138 |
+
|
| 2139 |
+
def get_dtype(name):
|
| 2140 |
+
if name == fake_out.get_name():
|
| 2141 |
+
return fake_out.get_dtype()
|
| 2142 |
+
return _get_dtype_real(name)
|
| 2143 |
+
|
| 2144 |
+
return get_dtype
|
| 2145 |
+
|
| 2146 |
+
def __init__(self, name: str):
|
| 2147 |
+
self.name = name
|
| 2148 |
+
|
| 2149 |
+
def maybe_append_choice(self, choices, **kwargs):
|
| 2150 |
+
"""
|
| 2151 |
+
Maybe generates a new ChoiceCaller and appends it into existing choices.
|
| 2152 |
+
|
| 2153 |
+
choices: A list of ChoiceCallers.
|
| 2154 |
+
kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
|
| 2155 |
+
"""
|
| 2156 |
+
|
| 2157 |
+
try:
|
| 2158 |
+
choices.append(self.generate(**kwargs))
|
| 2159 |
+
except NotImplementedError as e:
|
| 2160 |
+
pass
|
| 2161 |
+
|
| 2162 |
+
def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
|
| 2163 |
+
"""
|
| 2164 |
+
Generates a ChoiceCaller instance from the given arguments.
|
| 2165 |
+
"""
|
| 2166 |
+
|
| 2167 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py
ADDED
|
@@ -0,0 +1,1043 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from typing import Any, Callable, cast, List, Optional, Set, Union
|
| 7 |
+
from unittest.mock import patch
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils
|
| 11 |
+
|
| 12 |
+
from ..._dynamo.utils import counters
|
| 13 |
+
from .. import config, ir, lowering as L
|
| 14 |
+
from ..kernel.mm_common import mm_args
|
| 15 |
+
from ..select_algorithm import DataProcessorTemplateWrapper
|
| 16 |
+
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
|
| 17 |
+
from ..virtualized import ops, V
|
| 18 |
+
from .cpp import get_export_declaration
|
| 19 |
+
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
|
| 20 |
+
from .cpp_template import CppTemplate
|
| 21 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 22 |
+
from .cpp_utils import (
|
| 23 |
+
create_epilogue_with_attr,
|
| 24 |
+
DTYPE_TO_CPP,
|
| 25 |
+
GemmBlocking,
|
| 26 |
+
get_gemm_template_output_and_compute_dtype,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
log = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
GEMM_TEMPLATE = r"""
|
| 33 |
+
{{template.header().getvalue()}}
|
| 34 |
+
|
| 35 |
+
{{micro_gemm.codegen_define(kernel)}}
|
| 36 |
+
|
| 37 |
+
{%- if x_scale is not none %}
|
| 38 |
+
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
|
| 39 |
+
{%- else %}
|
| 40 |
+
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
|
| 43 |
+
extern "C" {{export_declaration}}
|
| 44 |
+
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
|
| 45 |
+
{
|
| 46 |
+
{{kernel.maybe_codegen_profile()}}
|
| 47 |
+
constexpr int64_t num_threads = {{num_threads}};
|
| 48 |
+
constexpr int64_t N = {{N}};
|
| 49 |
+
constexpr int64_t K = {{K}};
|
| 50 |
+
constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
|
| 51 |
+
constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
|
| 52 |
+
constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
|
| 53 |
+
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
|
| 54 |
+
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
|
| 55 |
+
|
| 56 |
+
{%- if is_dynamic_M %}
|
| 57 |
+
const int64_t M = {{kernel.size(GemmOut, 0)}};
|
| 58 |
+
const int64_t Mr_blocks = (M + Mr - 1) / Mr;
|
| 59 |
+
{%- if num_threads > 1 %}
|
| 60 |
+
int64_t Mt_blocks, Nt_blocks, Kt_blocks;
|
| 61 |
+
mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
|
| 62 |
+
{%- else %}
|
| 63 |
+
const auto Mt_blocks = Mr_blocks;
|
| 64 |
+
const auto Nt_blocks = Nr_blocks;
|
| 65 |
+
const auto Kt_blocks = Kr_blocks;
|
| 66 |
+
{%- endif %}
|
| 67 |
+
int64_t Mc_blocks, Nc_blocks, Kc_blocks;
|
| 68 |
+
uint32_t L1_cache_size = {{L1_cache_size}};
|
| 69 |
+
uint32_t L2_cache_size = {{L2_cache_size}};
|
| 70 |
+
mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
|
| 71 |
+
num_threads,
|
| 72 |
+
M,
|
| 73 |
+
N,
|
| 74 |
+
K,
|
| 75 |
+
Mr,
|
| 76 |
+
Nr,
|
| 77 |
+
Kr,
|
| 78 |
+
Mt_blocks,
|
| 79 |
+
Nt_blocks,
|
| 80 |
+
Kt_blocks,
|
| 81 |
+
Mc_blocks,
|
| 82 |
+
Nc_blocks,
|
| 83 |
+
Kc_blocks,
|
| 84 |
+
L1_cache_size,
|
| 85 |
+
L2_cache_size
|
| 86 |
+
);
|
| 87 |
+
const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
|
| 88 |
+
const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
|
| 89 |
+
const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
|
| 90 |
+
{%- else %}
|
| 91 |
+
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
|
| 92 |
+
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
|
| 93 |
+
constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
|
| 94 |
+
constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
|
| 95 |
+
constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
|
| 96 |
+
constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
|
| 97 |
+
constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
|
| 98 |
+
constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
|
| 99 |
+
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
|
| 100 |
+
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
|
| 101 |
+
constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
|
| 102 |
+
{%- endif %}
|
| 103 |
+
|
| 104 |
+
// make sure all partitions are assigned
|
| 105 |
+
{{kernel.assert_function}}(
|
| 106 |
+
Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
|
| 107 |
+
"Not all partitions are assigned."
|
| 108 |
+
);
|
| 109 |
+
|
| 110 |
+
{%- if maybe_k_slicing %}
|
| 111 |
+
std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
|
| 112 |
+
if (num_k_slices > 1) {
|
| 113 |
+
local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
|
| 114 |
+
}
|
| 115 |
+
{%- endif %}
|
| 116 |
+
|
| 117 |
+
{%- if num_threads > 1 %}
|
| 118 |
+
#pragma omp parallel num_threads({{num_threads}})
|
| 119 |
+
{
|
| 120 |
+
const int tid = omp_get_thread_num();
|
| 121 |
+
int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
|
| 122 |
+
mm_get_thread_blocks(
|
| 123 |
+
tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
|
| 124 |
+
m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
|
| 125 |
+
{%- if maybe_k_slicing %}
|
| 126 |
+
const int64_t k_group_id = tid / num_k_slices;
|
| 127 |
+
const int64_t k_slice_id = tid % num_k_slices;
|
| 128 |
+
{%- endif %}
|
| 129 |
+
{%- else %}
|
| 130 |
+
{
|
| 131 |
+
const int tid = 0;
|
| 132 |
+
const int64_t m_block_start = 0;
|
| 133 |
+
const int64_t m_block_end = Mr_blocks;
|
| 134 |
+
const int64_t n_block_start = 0;
|
| 135 |
+
const int64_t n_block_end = Nr_blocks;
|
| 136 |
+
const int64_t k_block_start = 0;
|
| 137 |
+
const int64_t k_block_end = Kr_blocks;
|
| 138 |
+
{%- endif %}
|
| 139 |
+
{{ micro_gemm.codegen_init(kernel) }}
|
| 140 |
+
{%- if use_local_acc %}
|
| 141 |
+
{%- set acc_buf_name = "local_acc_buf" %}
|
| 142 |
+
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
|
| 143 |
+
{%- endif %}
|
| 144 |
+
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
|
| 145 |
+
const int64_t m_start = mc * Mr;
|
| 146 |
+
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
|
| 147 |
+
const int64_t m_size = m_end - m_start;
|
| 148 |
+
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
| 149 |
+
const int64_t n_start = nc * Nr;
|
| 150 |
+
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
|
| 151 |
+
const int64_t n_size = n_end - n_start;
|
| 152 |
+
// NB: assume we pad N, nc_block_end won't exceed padded N here.
|
| 153 |
+
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
|
| 154 |
+
{%- if use_local_acc %}
|
| 155 |
+
{%- set acc = kernel.local_buffers[acc_buf_name] %}
|
| 156 |
+
{{ kernel.reinit_buffer_if_null(acc_buf_name) }}
|
| 157 |
+
{%- else %}
|
| 158 |
+
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
|
| 159 |
+
{%- endif %}
|
| 160 |
+
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
|
| 161 |
+
int64_t k_start = kc * Kr;
|
| 162 |
+
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
|
| 163 |
+
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
|
| 164 |
+
for (int64_t nci = nc; nci < nc_block_end; nci++) {
|
| 165 |
+
{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
|
| 166 |
+
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
|
| 167 |
+
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
|
| 168 |
+
if (kc == k_block_start) {
|
| 169 |
+
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
|
| 170 |
+
} else {
|
| 171 |
+
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
{%- if maybe_k_slicing %}
|
| 176 |
+
if (num_k_slices > 1) {
|
| 177 |
+
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
|
| 178 |
+
local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
|
| 179 |
+
} else
|
| 180 |
+
{%- endif %}
|
| 181 |
+
{
|
| 182 |
+
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
|
| 183 |
+
{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
|
| 184 |
+
{{ kernel.store_output(
|
| 185 |
+
tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
|
| 186 |
+
)|indent(20, false)
|
| 187 |
+
}}
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
{%- if maybe_k_slicing %}
|
| 192 |
+
if (num_k_slices > 1) {
|
| 193 |
+
#pragma omp barrier
|
| 194 |
+
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
|
| 195 |
+
// We slice M-dim and each thread in the k-slicing group works on a slice
|
| 196 |
+
const int64_t m_start_unsliced = mc * Mr;
|
| 197 |
+
const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
|
| 198 |
+
const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
|
| 199 |
+
const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
|
| 200 |
+
const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
|
| 201 |
+
const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
|
| 202 |
+
const int64_t m_size = m_end - m_start;
|
| 203 |
+
const int64_t m_offset = m_start - m_start_unsliced;
|
| 204 |
+
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
| 205 |
+
const int64_t n_start = nc * Nr;
|
| 206 |
+
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
|
| 207 |
+
const int64_t n_size = n_end - n_start;
|
| 208 |
+
const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
|
| 209 |
+
auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
|
| 210 |
+
for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
|
| 211 |
+
auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
|
| 212 |
+
for (int64_t m = m_offset; m < m_offset + m_size; m++) {
|
| 213 |
+
#pragma omp simd
|
| 214 |
+
for (int64_t n = 0; n < n_size; n++) {
|
| 215 |
+
{{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
{%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
|
| 220 |
+
{{ kernel.store_output(
|
| 221 |
+
tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
|
| 222 |
+
)|indent(20, false)
|
| 223 |
+
}}
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
{%- endif %}
|
| 228 |
+
{{ micro_gemm.codegen_finalize(kernel) }}
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_padded_n(n, block_n):
    """Round ``n`` up to the nearest multiple of ``block_n``.

    Used to pad the GEMM N dimension so the packed weight is a whole
    number of register blocks wide.
    """
    full_blocks, remainder = divmod(n, block_n)
    if remainder:
        full_blocks += 1
    return full_blocks * block_n
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class CppPackedGemmTemplate(CppTemplate):
|
| 239 |
+
def __init__(
    self,
    input_nodes,
    layout: ir.Layout,
    num_threads: int,
    register_blocking: GemmBlocking,
    beta=1,
    alpha=1,
    has_bias=False,
    epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
) -> None:
    """Construct a packed-weight CPP GEMM template.

    Records the GEMM problem sizes (m, n, k) from the output layout and
    the first input node, the register blocking chosen by the micro-gemm,
    the padded N (rounded up to a whole number of register blocks), and
    whether M is symbolic (dynamic shapes).
    """
    # Only these dtypes are handled by the generated kernel.
    assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
    super().__init__(
        "packed_gemm",
        input_nodes,
        layout,
        num_threads,
        epilogue_creator=epilogue_creator,
    )
    self.register_blocking = register_blocking
    self.beta = beta
    self.alpha = alpha
    self.has_bias = has_bias
    # Problem sizes: (m, n) from the output layout, k from input 0 (X).
    m, n = layout.size
    _, k = input_nodes[0].get_size()
    self.m = m
    self.n = n
    self.k = k
    # N padded up to a multiple of the register block so the packed
    # weight tensor is rectangular in blocks.
    self.padded_n = get_padded_n(n, register_blocking.block_n)
    # True when M is symbolic; blocking decisions then happen at runtime.
    self.is_dynamic_M = has_free_symbols((m,))
|
| 267 |
+
|
| 268 |
+
@cache_on_self
def thread_blocking(self) -> GemmBlocking:
    """
    NOTE [Thread blocking in Cpp GEMM]
    We use simple heuristics to decide the thread blocking:
    1. Make sure all threads are occupied as much as possible.
    2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
    3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
    TODO(jgong5): allow tuning various blocking options
    """

    @lru_cache(maxsize=100)
    def get_factors(number):
        # All factors of `number`, listed so that the more balanced
        # cofactor (number // i) comes first for each divisor pair.
        factors = []
        for i in range(int(number**0.5), 0, -1):
            if number % i == 0:
                factors.append(number // i)
                factors.append(i)
        return factors

    def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
        # Split the register-block grid evenly among the given per-dim
        # thread factors (ceil so every block is covered).
        thread_block_k = math.ceil(k_blocks / k_factor)
        thread_block_n = math.ceil(n_blocks / n_factor)
        thread_block_m = math.ceil(m_blocks / m_factor)
        return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

    assert (
        not self.is_dynamic_M
    ), "Unable to determine thread blocking for dynamic M."
    register_blocking = self.register_blocking
    # Problem size expressed in register blocks per dimension.
    m_blocks = math.ceil(self.m / register_blocking.block_m)
    n_blocks = math.ceil(self.n / register_blocking.block_n)
    k_blocks = math.ceil(self.k / register_blocking.block_k)
    factors = get_factors(self.num_threads)
    assert len(factors) > 0

    # User override: "m,n,k" thread factors from config bypass the heuristics.
    if config.cpp.gemm_thread_factors is not None:
        factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
        assert len(factors) == 3
        assert math.prod(factors) == self.num_threads
        return get_blocking(
            factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
        )

    # we favor square-sized thread blocks for good data reuse
    def get_better_blocking(blocking, best_blocking):
        # Prefer larger K per thread first; among equal K, prefer the
        # more square (smaller perimeter) m/n block in element units.
        if best_blocking is None:
            best_blocking = blocking
        else:
            block_m_size = blocking.block_m * register_blocking.block_m
            block_n_size = blocking.block_n * register_blocking.block_n
            best_block_m_size = best_blocking.block_m * register_blocking.block_m
            best_block_n_size = best_blocking.block_n * register_blocking.block_n
            if blocking.block_k > best_blocking.block_k:
                best_blocking = blocking
            elif (
                blocking.block_k == best_blocking.block_k
                and block_m_size + block_n_size
                < best_block_m_size + best_block_n_size
            ):
                best_blocking = blocking
        return best_blocking

    best_blocking = None
    # check if we can have a thread-blocking to occupy all threads without k-slicing
    for n_factor in factors:
        m_factor = self.num_threads // n_factor
        if n_blocks >= n_factor and m_blocks >= m_factor:
            blocking = get_blocking(
                m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
            )
            best_blocking = get_better_blocking(blocking, best_blocking)

    if best_blocking is None:
        # (m, n) alone cannot occupy all threads; try splitting K as well,
        # bounded by config.cpp.gemm_max_k_slices (0 means unbounded).
        for k_factor in factors:
            if k_blocks >= k_factor and (
                config.cpp.gemm_max_k_slices == 0
                or k_factor <= config.cpp.gemm_max_k_slices
            ):
                n_factors = get_factors(self.num_threads // k_factor)
                for n_factor in n_factors:
                    m_factor = (self.num_threads // k_factor) // n_factor
                    if n_blocks >= n_factor and m_blocks >= m_factor:
                        blocking = get_blocking(
                            m_factor,
                            n_factor,
                            k_factor,
                            m_blocks,
                            n_blocks,
                            k_blocks,
                        )
                        best_blocking = get_better_blocking(blocking, best_blocking)

    if best_blocking is None:
        # Last resort: accept partial occupancy — only one of m/n needs
        # enough blocks for its factor.
        for n_factor in factors:
            m_factor = self.num_threads // n_factor
            if n_blocks >= n_factor or m_blocks >= m_factor:
                blocking = get_blocking(
                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                )
                best_blocking = get_better_blocking(blocking, best_blocking)

    assert best_blocking is not None
    return best_blocking
|
| 372 |
+
|
| 373 |
+
@cache_on_self
def cache_blocking(self) -> GemmBlocking:
    """Decide the cache-level blocking (Mc, Nc, Kc, in register blocks).

    Only valid for static M; asserts otherwise. Uses the host's L1d/L2
    cache sizes (queried via torch._C._cpu) and the already-computed
    thread blocking to size the per-thread working set. See
    NOTE [CPP GEMM Cache Blocking Algorithm] below for the strategy.
    """

    def get_cache_blocking(register_blocking, thread_blocking):
        Mr = register_blocking.block_m
        Nr = register_blocking.block_n
        Kr = register_blocking.block_k

        Mt_blocks = thread_blocking.block_m
        Nt_blocks = thread_blocking.block_n
        Kt_blocks = thread_blocking.block_k

        # User override: "Mc,Nc,Kc" from config, clamped to the thread
        # blocking so a cache block never exceeds a thread's share.
        if config.cpp.gemm_cache_blocking is not None:
            blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
            assert len(blockings) == 3
            Mc_blocks, Nc_blocks, Kc_blocks = blockings
            return (
                min(Mc_blocks, Mt_blocks),
                min(Nc_blocks, Nt_blocks),
                min(Kc_blocks, Kt_blocks),
            )

        # The ratios below are empirically determined to decide
        # the effective sizes of L1 and L2.
        # TODO: tune the factor here
        L1_limit_factor = 0.8
        L2_limit_factor = 0.5

        L1_cache_size = (
            torch._C._cpu._L1d_cache_size()
        )  # per core cache size in Bytes
        assert (
            L1_cache_size > 0
        ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
        L1 = L1_cache_size * L1_limit_factor

        L2_cache_size = (
            torch._C._cpu._L2_cache_size()
        )  # per core cache size in Bytes
        assert (
            L2_cache_size > 0
        ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
        L2 = L2_cache_size * L2_limit_factor

        def get_num_byte(dtype):
            # Element size in bytes for the given dtype.
            return torch.tensor([], dtype=dtype).element_size()

        num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
        num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())

        # NOTE [CPP GEMM Cache Blocking Algorithm]
        # Our overall strategy is to
        # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
        #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
        #    decide Kc. We want to make Mc large enough to better reuse B.
        # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
        #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

        # Step 1: Decide Kc assuming B block is L1-reside.
        size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
        Kc_blocks = Kt_blocks
        if size_cache_B > L1:
            # NOTE(review): floor() can yield 0 if Kr*Nr*num_byte_B > L1
            # (very large register blocks / tiny L1) — confirm callers
            # never hit a zero cache block.
            Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))

        # Step 2: Decide Mc assuming A block is L2-reside.
        min_Mc_ratio = 2  # TODO(jgong5): something to tune?
        min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
        assert min_Mc_blocks >= 1
        Kt_bytes = Kt_blocks * Kr * num_byte_A
        if min_Mc_blocks * Mr * Kt_bytes < L2:
            # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
            # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
            # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
            # in L1.
            Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
            Nc_blocks = 1
        else:
            # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
            # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
            Mc_blocks = Mt_blocks
            Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
            Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
            Kc_bytes = Kc_blocks * Kr * num_byte_A
            if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
                # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
                # assuming Mc == Nc for good data reuse.
                M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
                if M_max < Mc_blocks * Mr:
                    Mc_blocks = math.floor(M_max / Mr)
                    Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)

        return Mc_blocks, Nc_blocks, Kc_blocks

    assert (
        not self.is_dynamic_M
    ), "Unable to determine cache blocking for dynamic M."
    register_blocking = self.register_blocking
    thread_blocking = self.thread_blocking()

    return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
|
| 472 |
+
|
| 473 |
+
def log_blockings(self):
    """Debug-log the register/cache/thread blockings and thread occupancy.

    For dynamic M, only the register blocking is logged — thread and
    cache blockings are computed at runtime in that case.
    """
    log.debug(f"Register blocking: {self.register_blocking}")  # noqa: G004
    if self.is_dynamic_M:
        # thread and cache blockings are determined at runtime for dynamic shapes
        return
    log.debug(f"Cache blocking: {self.cache_blocking()}")  # noqa: G004
    thread_blocking = self.thread_blocking()
    log.debug(f"Thread blocking: {thread_blocking}")  # noqa: G004

    def get_occupancy():
        # How many thread-block tiles the problem spans per dimension;
        # (1, 1, 1) means a single tile per dim, i.e. full occupancy by
        # the chosen blocking.
        m_blocks = math.ceil(self.m / self.register_blocking.block_m)
        n_blocks = math.ceil(self.n / self.register_blocking.block_n)
        k_blocks = math.ceil(self.k / self.register_blocking.block_k)
        m = math.ceil(m_blocks / thread_blocking.block_m)
        n = math.ceil(n_blocks / thread_blocking.block_n)
        k = math.ceil(k_blocks / thread_blocking.block_k)
        return (m, n, k)

    log.debug(
        f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
    )
|
| 494 |
+
|
| 495 |
+
def maybe_k_slicing(self):
    """Return whether K-dim slicing may be used for this GEMM.

    K-slicing splits the reduction dimension across threads; partial
    results are then summed after a barrier in the generated kernel.
    """
    # A single thread has nothing to slice across.
    if self.num_threads == 1:
        return False
    # TODO(jgong5): perhaps use size hint to decide?
    if self.is_dynamic_M:
        return True
    # Static shapes: slice K only when there are more K register blocks
    # than one thread's K share in the thread blocking.
    num_k_blocks = math.ceil(self.k / self.register_blocking.block_k)
    return num_k_blocks > self.thread_blocking().block_k
|
| 505 |
+
|
| 506 |
+
@staticmethod
def add_choices(
    choices,
    layout,
    input_nodes,
    beta=1,
    alpha=1,
    has_bias=False,
    trans_w=False,
    input_indices=None,
    epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
    """Register a CppPackedGemmTemplate autotune choice for this GEMM.

    Builds a DataProcessorTemplateWrapper whose preprocessor reorders
    inputs to [x, w, inp?, ...], densifies/transposes W as requested, and
    packs W into [padded_n // block_n, k, block_n] blocks for the chosen
    micro-gemm; the postprocessor swaps the packed weight constant into
    the template buffer and prunes the original weight when it has no
    other users. Appends the resulting choice to `choices` and returns
    the template wrapper.
    """
    if input_indices is None:
        input_indices = list(range(len(input_nodes)))

    def reorder_and_filter(inputs, layout_or_out):
        # Normalize the input order expected by the template.
        if has_bias:
            assert len(input_indices) >= 3
            # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
            inp_idx = input_indices[0]
            x_idx = input_indices[1]
            w_idx = input_indices[2]
            return [
                inputs[x_idx],
                inputs[w_idx],
                inputs[inp_idx],
                *[inputs[idx] for idx in input_indices[3:]],
            ], layout_or_out
        else:
            assert len(input_indices) >= 2
            return [inputs[idx] for idx in input_indices], layout_or_out

    def maybe_to_dense(inputs, layout_or_out):
        # Densify an MKLDNN weight tensor so it can be packed below.
        new_inputs = list(inputs)
        if isinstance(inputs[1], torch.Tensor):
            W = inputs[1]
            new_inputs[1] = W.to_dense() if W.is_mkldnn else W
        return new_inputs, layout_or_out

    def normalize_shapes(inputs, layout_or_out):
        # Transpose W (and broadcast bias to [M, N]) when trans_w is set;
        # handles both IR nodes (compile path) and real tensors (autotune path).
        if not trans_w:
            return inputs, layout_or_out
        new_inputs = list(inputs)
        X = inputs[0]
        W = inputs[1]
        B = inputs[2] if has_bias else None
        if isinstance(W, ir.IRNode):
            if trans_w:
                if not isinstance(W, ir.TensorBox):
                    W = ir.TensorBox(W)
                W = L.permute(W, [1, 0])
        else:
            if trans_w:
                assert isinstance(W, torch.Tensor)
                W = W.transpose(0, 1)
        if B is not None:
            if isinstance(B, ir.IRNode):
                if not isinstance(B, ir.TensorBox):
                    B = ir.TensorBox(B)
                B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
            else:
                assert isinstance(B, torch.Tensor)
                B = B.expand(X.shape[0], B.shape[-1])
        new_inputs[1] = W
        if B is not None:
            new_inputs[2] = B
        return new_inputs, layout_or_out

    # TODO(jgong5): decide proper number of threads per problem size
    num_threads = parallel_num_threads()
    new_inputs, _ = normalize_shapes(
        *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
    )
    m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
    output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
        new_inputs[0].get_dtype()
    )
    micro_gemm = create_micro_gemm(
        "micro_gemm",
        m,
        n,
        k,
        input_dtype=new_inputs[0].get_dtype(),
        input2_dtype=new_inputs[1].get_dtype(),
        output_dtype=output_dtype,
        compute_dtype=compute_dtype,
        alpha=alpha,
        num_threads=num_threads,
    )
    assert micro_gemm is not None
    _, block_n, _ = micro_gemm.register_blocking
    padded_n = get_padded_n(n, block_n)

    def pack_weight(inputs, layout_or_out):
        # Reblock W to [padded_n // block_n, k, block_n]. For IR nodes
        # only the layout is rewritten (data packed later in postprocessor);
        # for real tensors the data is physically padded/reshaped here.
        W = inputs[1]
        new_inputs = list(inputs)
        blocked_w: Union[ir.IRNode, torch.Tensor] = W
        if isinstance(W, ir.IRNode):
            new_size = [padded_n // block_n, k, block_n]
            blocked_w = ir.Buffer(
                W.get_name(),  # Borrow the registered buffer name
                ir.FixedLayout(
                    W.get_device(),
                    W.get_dtype(),
                    new_size,
                    ir.FlexibleLayout.contiguous_strides(new_size),
                    0,
                ),
            )
        else:
            blocked_w = (
                torch.nn.functional.pad(W, (0, padded_n - n))
                .reshape(k, padded_n // block_n, block_n)
                .transpose(0, 1)
                .contiguous()
            )
            if micro_gemm.get_b_layout() != LayoutType.NORMAL:
                # Interleave K in groups of vnni_size for VNNI micro-kernels.
                layout_str = (
                    "VNNI4"
                    if micro_gemm.get_b_layout() == LayoutType.VNNI4
                    else "VNNI2"
                )
                assert micro_gemm.get_b_layout() in [
                    LayoutType.VNNI2,
                    LayoutType.VNNI4,
                ], f"We only support {layout_str} for now"
                vnni_size = (
                    4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
                )
                assert (
                    k % vnni_size == 0
                ), f"k should be divisible by vnni_size for {layout_str} layout"
                blocked_w = (
                    blocked_w.view(
                        padded_n // block_n, k // vnni_size, vnni_size, block_n
                    )
                    .transpose(-1, -2)
                    .contiguous()
                    .view(padded_n // block_n, k, block_n)
                )
            # normalize stride to be "contiguous_strides" per size
            # this avoids the problems in L.view during template codegen
            new_stride = [1]
            for sz in reversed(blocked_w.shape[1:]):
                new_stride.insert(0, new_stride[0] * sz)
            blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
        new_inputs[1] = blocked_w

        def _is_int8_gemm(inputs):
            # uint8 activation marks the int8 (quantized) GEMM path.
            return (
                isinstance(inputs[0], ir.IRNode)
                and inputs[0].get_dtype() == torch.uint8
            ) or (
                isinstance(inputs[0], torch.Tensor)
                and inputs[0].dtype == torch.uint8
            )

        if _is_int8_gemm(new_inputs):
            # Append the per-column weight-sum compensation tensor used
            # for zero-point correction in int8 GEMM.
            BCompensate = None
            if isinstance(W, ir.IRNode):
                BCompensate = V.graph.add_tensor_constant(
                    V.graph.constants[W.get_name() + "_BMatrixCompens"],
                    W.get_name() + "_BMatrixCompens",
                )
            else:
                BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
            new_inputs.append(BCompensate)
        return new_inputs, layout_or_out

    def preprocessor(inputs, layout):
        # Full input pipeline: reorder -> densify -> transpose/broadcast -> pack.
        return pack_weight(
            *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
        )

    def postprocessor(output):
        # After the template wins autotuning, replace the weight input of
        # the template buffer with the physically packed constant and try
        # to prune the original (now unused) weight constant.
        if isinstance(output, ir.TensorBox):
            # prepack the weight as input to the template buffer
            template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
            assert isinstance(template_buffer, ir.CppTemplateBuffer)
            new_input_nodes, _ = reorder_and_filter(input_nodes, layout)

            W_node = new_input_nodes[1]
            assert W_node.get_name() in V.graph.constants
            W = V.graph.constants[W_node.get_name()]
            new_input_nodes[1] = W
            new_input_nodes, _ = pack_weight(
                *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
            )

            # By using the new packed weight for the GEMM template, we can prune the
            # old weight if it has no other users. This saves memory but makes the FX graph
            # non-retraceable. To support retracing, we can add a repack node to the
            # FX graph. For example:
            # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
            W_tensor_users = 0
            for node in reversed(V.graph.graph.nodes):
                # Case may happen when the wgt tensor is used by more than 1 get_attr node
                # https://github.com/pytorch/pytorch/issues/134998
                if node.op == "get_attr" and hasattr(
                    V.graph.module, node.name
                ):  # wgt might already be deleted
                    comp_tensor = getattr(V.graph.module, node.name)
                    # Count get_attr nodes whose tensor aliases W's storage
                    # (dense and mkldnn tensors need different data_ptr APIs).
                    if (
                        W.is_mkldnn == comp_tensor.is_mkldnn
                        and W.dtype == comp_tensor.dtype
                        and W.device == comp_tensor.device
                        and (
                            (
                                not W.is_mkldnn
                                and (
                                    W.untyped_storage().data_ptr()
                                    == comp_tensor.untyped_storage().data_ptr()
                                )
                            )
                            or (
                                W.is_mkldnn
                                and (
                                    torch.ops.mkldnn.data_ptr(W)
                                    == torch.ops.mkldnn.data_ptr(comp_tensor)
                                )
                            )
                        )
                    ):
                        W_tensor_users += 1

            for node in reversed(V.graph.graph.nodes):
                # The wgt tensor has been used by only 1 get_attr node
                # The get_attr node has only 1 user fx node
                if (
                    node.name == W_node.get_name()
                    and len(node.users) == 1
                    and W_tensor_users == 1
                ):
                    del V.graph.constants[node.name]
                    delattr(V.graph.module, node.name)
                    delattr(V.graph.graph.owning_module, node.name)

            W_packed = new_input_nodes[1]
            W_packed_constant = V.graph.add_tensor_constant(W_packed)
            template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                W_packed_constant
            )
        return output

    template = DataProcessorTemplateWrapper(
        CppPackedGemmTemplate,
        preprocessor,
        postprocessor,
        input_nodes=input_nodes,
        layout=layout,
        num_threads=num_threads,
        register_blocking=micro_gemm.register_blocking,
        beta=beta,
        alpha=alpha,
        has_bias=has_bias,
        epilogue_creator=epilogue_creator,
    )
    template.maybe_append_choice(choices)
    return template
|
| 765 |
+
|
| 766 |
+
def render( # type: ignore[override,return]
|
| 767 |
+
self,
|
| 768 |
+
kernel: CppTemplateKernel,
|
| 769 |
+
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
| 770 |
+
flag_template_buffer_has_other_users: Optional[bool] = None,
|
| 771 |
+
epilogue_nodes: Optional[List[ir.IRNode]] = None,
|
| 772 |
+
**kwargs,
|
| 773 |
+
) -> str:
|
| 774 |
+
assert len(self.input_nodes) >= 2
|
| 775 |
+
|
| 776 |
+
int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8
|
| 777 |
+
x_scale = None
|
| 778 |
+
x_zp = None
|
| 779 |
+
w_scale = None
|
| 780 |
+
w_zp = None
|
| 781 |
+
if int8_gemm:
|
| 782 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 783 |
+
bias_idx = 2 if self.has_bias else 1
|
| 784 |
+
inp = self.input_nodes[bias_idx] if self.has_bias else None
|
| 785 |
+
x_scale = self.input_nodes[bias_idx + 1]
|
| 786 |
+
x_zp = self.input_nodes[bias_idx + 2]
|
| 787 |
+
w_scale = self.input_nodes[bias_idx + 3]
|
| 788 |
+
w_zp = self.input_nodes[bias_idx + 4]
|
| 789 |
+
Y = self.output_node
|
| 790 |
+
else:
|
| 791 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 792 |
+
Y = self.output_node
|
| 793 |
+
inp = self.input_nodes[2] if self.has_bias else None
|
| 794 |
+
|
| 795 |
+
template_buffer_has_other_users = None
|
| 796 |
+
|
| 797 |
+
if template_buffer_node is not None:
|
| 798 |
+
# Use the updated prepacked weight buffer
|
| 799 |
+
W = template_buffer_node.inputs[1]
|
| 800 |
+
Y = template_buffer_node
|
| 801 |
+
|
| 802 |
+
assert flag_template_buffer_has_other_users is not None
|
| 803 |
+
template_buffer_has_other_users = flag_template_buffer_has_other_users
|
| 804 |
+
|
| 805 |
+
template_buffer = Y
|
| 806 |
+
gemm_output_buffer = template_buffer
|
| 807 |
+
|
| 808 |
+
epilogues: List[ir.IRNode] = []
|
| 809 |
+
reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
|
| 810 |
+
epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
|
| 811 |
+
fake_buffers: List[ir.Buffer] = []
|
| 812 |
+
Y_aliases: Set[str] = set()
|
| 813 |
+
|
| 814 |
+
use_local_acc = (
|
| 815 |
+
self.layout.dtype != torch.float
|
| 816 |
+
or template_buffer_has_other_users
|
| 817 |
+
or int8_gemm
|
| 818 |
+
or self.padded_n != self.n
|
| 819 |
+
or self.maybe_k_slicing()
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
|
| 823 |
+
# but we'd better move it here to align with fp.
|
| 824 |
+
if inp is not None and self.beta != 0 and not int8_gemm:
|
| 825 |
+
# add an epilogue for bias add
|
| 826 |
+
def _bias_add_epilogue(buf):
|
| 827 |
+
return create_epilogue_with_attr(
|
| 828 |
+
buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
epilogue_creators.append(_bias_add_epilogue)
|
| 832 |
+
|
| 833 |
+
if self.epilogue_creator is not None:
|
| 834 |
+
epilogue_creators.append(self.epilogue_creator)
|
| 835 |
+
|
| 836 |
+
# When the GEMM output buffer is localized but it has users other than the epilogue nodes,
|
| 837 |
+
# we need to copy the value in the GEMM output local buffer to a global buffer.
|
| 838 |
+
def need_copy_from_local_to_global_buffer_epilogue(
|
| 839 |
+
use_local_acc, template_buffer_has_other_users, epilogue_creators
|
| 840 |
+
):
|
| 841 |
+
# The GEMM output buffer is a global buffer, thus copy is not needed.
|
| 842 |
+
if not use_local_acc:
|
| 843 |
+
return False
|
| 844 |
+
|
| 845 |
+
# The possible value of template_buffer_has_other_users is (None, False, True)
|
| 846 |
+
# It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
|
| 847 |
+
# extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
|
| 848 |
+
# 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
|
| 849 |
+
# 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
|
| 850 |
+
# GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
|
| 851 |
+
if not template_buffer_has_other_users:
|
| 852 |
+
return False
|
| 853 |
+
|
| 854 |
+
# When bias is not None or self.epilogue_creator is not None,
|
| 855 |
+
# there will be epilogue_creators after the GEMM.
|
| 856 |
+
# The GEMM output buffer is localized while
|
| 857 |
+
# the output buffer of the epilogue_creators is a global buffer.
|
| 858 |
+
if epilogue_creators:
|
| 859 |
+
return False
|
| 860 |
+
|
| 861 |
+
return True
|
| 862 |
+
|
| 863 |
+
if need_copy_from_local_to_global_buffer_epilogue(
|
| 864 |
+
use_local_acc, template_buffer_has_other_users, epilogue_creators
|
| 865 |
+
):
|
| 866 |
+
|
| 867 |
+
def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
|
| 868 |
+
dtype = self.layout.dtype
|
| 869 |
+
input_loader = input_buffer.make_loader()
|
| 870 |
+
|
| 871 |
+
def copy_inner(index):
|
| 872 |
+
input = input_loader(index)
|
| 873 |
+
result = ops.to_dtype(input, dtype)
|
| 874 |
+
return result
|
| 875 |
+
|
| 876 |
+
return ir.Pointwise(
|
| 877 |
+
device=input_buffer.get_device(),
|
| 878 |
+
dtype=self.layout.dtype,
|
| 879 |
+
inner_fn=copy_inner,
|
| 880 |
+
ranges=input_buffer.get_size(),
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
|
| 884 |
+
|
| 885 |
+
# NOTE [How CPP GEMM template epilogues are organized]
|
| 886 |
+
# gemm_output_buffer
|
| 887 |
+
# --> zero or more in-template epilogues (created by `epilogue_creators`) -->
|
| 888 |
+
# template_buffer
|
| 889 |
+
# --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
|
| 890 |
+
# Y
|
| 891 |
+
if epilogue_creators:
|
| 892 |
+
gemm_output_name = "buf_GemmOut"
|
| 893 |
+
gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
|
| 894 |
+
current_input_buffer = gemm_output_buffer
|
| 895 |
+
for i, creator in enumerate(epilogue_creators):
|
| 896 |
+
if i == len(epilogue_creators) - 1:
|
| 897 |
+
buffer_name = template_buffer.get_name()
|
| 898 |
+
else:
|
| 899 |
+
buffer_name = f"buf_GemmOut_epilogue_{i}"
|
| 900 |
+
epilogues.append(
|
| 901 |
+
ir.ComputedBuffer(
|
| 902 |
+
name=buffer_name,
|
| 903 |
+
layout=template_buffer.layout,
|
| 904 |
+
data=creator(current_input_buffer),
|
| 905 |
+
)
|
| 906 |
+
)
|
| 907 |
+
fake_buffers.append(current_input_buffer)
|
| 908 |
+
Y_aliases.add(current_input_buffer.get_name())
|
| 909 |
+
reindexers.append(None)
|
| 910 |
+
if i < len(epilogue_creators) - 1:
|
| 911 |
+
current_input_buffer = ir.Buffer(
|
| 912 |
+
buffer_name, template_buffer.layout
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
|
| 916 |
+
|
| 917 |
+
if epilogue_nodes:
|
| 918 |
+
epilogues.extend(epilogue_nodes)
|
| 919 |
+
assert Y.get_numel() == epilogues[-1].get_numel()
|
| 920 |
+
Y = cast(ir.Buffer, epilogues[-1])
|
| 921 |
+
|
| 922 |
+
if not template_buffer_has_other_users:
|
| 923 |
+
Y_aliases.add(template_buffer.get_name())
|
| 924 |
+
|
| 925 |
+
if (
|
| 926 |
+
Y.get_size() == template_buffer.get_size()
|
| 927 |
+
and Y.get_stride() == template_buffer.get_stride()
|
| 928 |
+
):
|
| 929 |
+
reindexers.extend([None] * len(epilogue_nodes))
|
| 930 |
+
Y_2d = Y
|
| 931 |
+
else:
|
| 932 |
+
|
| 933 |
+
def get_reindexer(epilogue_node):
|
| 934 |
+
# From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
|
| 935 |
+
# template_buffer:
|
| 936 |
+
# size (324, 512), stride (512, 1)
|
| 937 |
+
# epilogue_node_ordered (ordered by stride decreasingly, in dense format):
|
| 938 |
+
# size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
|
| 939 |
+
stride_order = list(
|
| 940 |
+
ir.get_stride_order(
|
| 941 |
+
V.graph.sizevars.size_hints(epilogue_node.get_stride())
|
| 942 |
+
)
|
| 943 |
+
)
|
| 944 |
+
fill_order = ir.stride_order2fill_order(stride_order)
|
| 945 |
+
reversed_fill_order = list(reversed(fill_order))
|
| 946 |
+
size_with_stride_ordered_decreasingly = [
|
| 947 |
+
epilogue_node.get_size()[i] for i in reversed_fill_order
|
| 948 |
+
]
|
| 949 |
+
reshape_reindex = ir.View.dynamic_reshape_indexer(
|
| 950 |
+
size_with_stride_ordered_decreasingly,
|
| 951 |
+
template_buffer.get_size(),
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
# From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
|
| 955 |
+
# epilogue_node_ordered (ordered by stride decreasingly, in dense format):
|
| 956 |
+
# size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
|
| 957 |
+
# epilogue_node:
|
| 958 |
+
# size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
|
| 959 |
+
from_stride_ordered_decreasingly_to_epilogue_node_order = [
|
| 960 |
+
(len(stride_order) - 1) - stride_order[i]
|
| 961 |
+
for i in range(len(stride_order))
|
| 962 |
+
]
|
| 963 |
+
stride_reindex = ir.same_reorder(
|
| 964 |
+
from_stride_ordered_decreasingly_to_epilogue_node_order
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)
|
| 968 |
+
return reindexer
|
| 969 |
+
|
| 970 |
+
reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item]
|
| 971 |
+
if isinstance(Y, ir.BaseView):
|
| 972 |
+
storage = ir.StorageBox(Y.unwrap_view())
|
| 973 |
+
else:
|
| 974 |
+
assert isinstance(Y, ir.Buffer)
|
| 975 |
+
storage = ir.StorageBox(Y)
|
| 976 |
+
Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout())
|
| 977 |
+
|
| 978 |
+
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
|
| 979 |
+
X.get_dtype()
|
| 980 |
+
)
|
| 981 |
+
micro_gemm = create_micro_gemm(
|
| 982 |
+
f"{kernel.kernel_name}_micro_gemm",
|
| 983 |
+
self.m,
|
| 984 |
+
self.n,
|
| 985 |
+
self.k,
|
| 986 |
+
input_dtype=X.get_dtype(),
|
| 987 |
+
input2_dtype=W.get_dtype(),
|
| 988 |
+
output_dtype=output_dtype,
|
| 989 |
+
compute_dtype=compute_dtype,
|
| 990 |
+
alpha=self.alpha,
|
| 991 |
+
num_threads=self.num_threads,
|
| 992 |
+
)
|
| 993 |
+
assert micro_gemm is not None
|
| 994 |
+
assert self.register_blocking == micro_gemm.register_blocking
|
| 995 |
+
self.log_blockings()
|
| 996 |
+
if isinstance(micro_gemm, CppMicroGemmAMX):
|
| 997 |
+
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
|
| 998 |
+
|
| 999 |
+
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
|
| 1000 |
+
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
|
| 1001 |
+
|
| 1002 |
+
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
|
| 1003 |
+
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
|
| 1004 |
+
|
| 1005 |
+
options = dict(
|
| 1006 |
+
X=X,
|
| 1007 |
+
W=W,
|
| 1008 |
+
inp=inp,
|
| 1009 |
+
Y=Y,
|
| 1010 |
+
N=self.n,
|
| 1011 |
+
K=self.k,
|
| 1012 |
+
PADDED_N=self.padded_n,
|
| 1013 |
+
GemmOut=gemm_output_buffer,
|
| 1014 |
+
aliases={alias: Y.get_name() for alias in Y_aliases},
|
| 1015 |
+
beta=self.beta,
|
| 1016 |
+
alpha=self.alpha,
|
| 1017 |
+
num_threads=self.num_threads,
|
| 1018 |
+
micro_gemm=micro_gemm,
|
| 1019 |
+
is_dynamic_M=self.is_dynamic_M,
|
| 1020 |
+
template=self,
|
| 1021 |
+
kernel=kernel,
|
| 1022 |
+
export_declaration=get_export_declaration(),
|
| 1023 |
+
epilogue_nodes=epilogues,
|
| 1024 |
+
reindexers=reindexers,
|
| 1025 |
+
Y_2d=Y_2d,
|
| 1026 |
+
use_local_acc=use_local_acc,
|
| 1027 |
+
maybe_k_slicing=self.maybe_k_slicing(),
|
| 1028 |
+
x_scale=x_scale,
|
| 1029 |
+
x_zp=x_zp,
|
| 1030 |
+
w_scale=w_scale,
|
| 1031 |
+
w_zp=w_zp,
|
| 1032 |
+
acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
|
| 1033 |
+
DTYPE_TO_CPP=DTYPE_TO_CPP,
|
| 1034 |
+
L1_cache_size=L1_cache_size,
|
| 1035 |
+
L2_cache_size=L2_cache_size,
|
| 1036 |
+
config=config,
|
| 1037 |
+
)
|
| 1038 |
+
with contextlib.ExitStack() as stack:
|
| 1039 |
+
for buf in fake_buffers:
|
| 1040 |
+
stack.enter_context(
|
| 1041 |
+
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
|
| 1042 |
+
)
|
| 1043 |
+
return self._template_from_string(GEMM_TEMPLATE).render(**options)
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import dataclasses
|
| 3 |
+
import sys
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Callable, Dict, List, Optional, Type
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .. import ir
|
| 12 |
+
from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
|
| 13 |
+
from ..utils import IndentedBuffer, parallel_num_threads
|
| 14 |
+
from ..virtualized import V
|
| 15 |
+
from .common import KernelTemplate
|
| 16 |
+
from .cpp_template_kernel import CppTemplateKernel
|
| 17 |
+
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LayoutType(Enum):
    """Memory layout expected for the B (weight) operand of a micro-gemm kernel."""

    NORMAL = 0  # plain row-major layout
    VNNI2 = 1  # VNNI packing of 2 K-rows (matches vnni_size=2 for non-int8 inputs)
    VNNI4 = 2  # VNNI packing of 4 K-rows (matches vnni_size=4 for uint8 inputs)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# MSVC spells the restrict qualifier differently; see get_restrict_keyword().
_IS_WINDOWS = sys.platform == "win32"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_restrict_keyword() -> str:
    """Return the compiler-specific spelling of the C/C++ ``restrict`` qualifier.

    MSVC only accepts ``__restrict``
    (https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170),
    while GCC/Clang use ``__restrict__``.
    """
    return "__restrict" if _IS_WINDOWS else "__restrict__"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CppMicroGemm:
    """
    A class that codegens a kernel that computes small-sized matrix multiplication.

    A micro GEMM kernel is responsible for register blocking, instruction selection,
    and other CPU architecture-specific optimizations.

    The subclasses need to override `codegen_define` to define the kernel function
    that is called by the code generated by `codegen_call`.
    """

    # Jinja template for the C++ declaration of the generated kernel. The `accum`
    # template parameter selects between `C += A @ B` and `C = A @ B`.
    # TODO(jgong5): support constant shapes and lds as template args.
    DECLARE_KERNEL = r"""
template <bool accum>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
    {{kernel_extra_args_declare}}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
)
"""

    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ) -> None:
        """Record the kernel name, operand/compute dtypes, register blocking
        (a GemmBlocking of block_m/block_n/block_k) and the scalar `alpha`
        multiplier applied to A @ B."""
        self.name = name
        self.input_dtype = input_dtype
        assert input2_dtype is not None
        self.input2_dtype = input2_dtype
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha

    def get_common_options(self):
        """Return the template-substitution options shared by all micro-gemm
        templates (dtype names mapped to C++ types, kernel name, alpha, etc.)."""
        if self.input_dtype == torch.uint8:
            # int8 gemm here means u8 activations x s8 weights with s32 accumulation
            assert self.compute_dtype == torch.int32
            assert self.output_dtype == torch.int32
            assert self.input2_dtype == torch.int8
        return {
            "torch": torch,
            "kernel_name": self.name,
            "input_dtype": self.input_dtype,
            "input2_dtype": self.input2_dtype,
            "output_dtype": self.output_dtype,
            "compute_dtype": self.compute_dtype,
            "input_t": DTYPE_TO_CPP[self.input_dtype],
            "input2_t": DTYPE_TO_CPP[self.input2_dtype],
            "output_t": DTYPE_TO_CPP[self.output_dtype],
            "compute_t": DTYPE_TO_CPP[self.compute_dtype],
            "alpha": self.alpha,
            "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
            "int8_gemm": self.input_dtype == torch.uint8,
            # VNNI packs 4 K-rows for u8 inputs, 2 otherwise
            "vnni_size": 4 if self.input_dtype == torch.uint8 else 2,
            "restrict_keyword": get_restrict_keyword(),
        }

    def get_kernel_declaration(self):
        """Render DECLARE_KERNEL into the C++ declaration string for this kernel."""
        options = self.get_common_options()
        return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)

    def get_kernel_extra_args_declare(self) -> str:
        """Extra parameter declarations injected into DECLARE_KERNEL.
        Subclasses (e.g. AMX) may override; empty by default."""
        return ""

    def get_kernel_extra_args(self) -> str:
        """Extra arguments passed at the call site by `codegen_call`.
        Must stay in sync with `get_kernel_extra_args_declare`; empty by default."""
        return ""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Generate the C++ definition of the kernel. Must be overridden."""
        raise NotImplementedError

    def codegen_call(
        self,
        kernel: CppTemplateKernel,
        A: ir.Buffer,
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
        """
        # Pointers to the first element of each operand within the buffers.
        A_ptr = f"&({kernel.index(A, [0, 0])})"
        B_ptr = f"&({kernel.index(B, [0, 0])})"
        C_ptr = f"&({kernel.index(C, [0, 0])})"
        # Problem sizes come from the output (M, N) and A's inner dim (K);
        # leading dimensions are the row strides of each 2D buffer.
        M = kernel.size(C, 0)
        N = kernel.size(C, 1)
        K = kernel.size(A, 1)
        lda = kernel.stride(A, 0)
        ldb = kernel.stride(B, 0)
        ldc = kernel.stride(C, 0)
        res = IndentedBuffer()
        res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
        with res.indent():
            extra_args = self.get_kernel_extra_args()
            if extra_args:
                res.writeline(extra_args)
            res.writeline(f"{A_ptr},")
            res.writeline(f"{B_ptr},")
            res.writeline(f"{C_ptr},")
            res.writeline(f"{M},")
            res.writeline(f"{N},")
            res.writeline(f"{K},")
            res.writeline(f"{lda},")
            res.writeline(f"{ldb},")
            res.writeline(f"{ldc}")
        res.writeline(");")
        return res.getvalue()

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Per-thread setup code emitted before any kernel calls (e.g. AMX tile
        state). Empty by default."""
        return ""

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Per-thread teardown code emitted after all kernel calls. Empty by default."""
        return ""

    def get_b_layout(self) -> LayoutType:
        """Layout the B operand must be packed into before calling this kernel."""
        return LayoutType.NORMAL
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclasses.dataclass
class CppMicroGemmConfig:
    """One supported (dtypes, vector ISA, register blocking) combination for a
    micro-gemm class; registered through `register_micro_gemm`."""

    input_dtype: torch.dtype
    input2_dtype: torch.dtype
    output_dtype: torch.dtype
    compute_dtype: torch.dtype
    # Vector ISA class (e.g. VecAVX2/VecAVX512/VecAMX) this config requires.
    vec_isa_cls: Type[VecISA]
    register_blocking: GemmBlocking
    # Optional extra eligibility predicate beyond dtype/ISA matching; called with
    # (config, m, n, k, alpha, num_threads) — see `check_amx_extra` for an example.
    extra_check: Optional[Callable[..., bool]] = None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Registry mapping each CppMicroGemm subclass to the configs it was
# registered with via `register_micro_gemm`.
micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def register_micro_gemm(*configs):
    """Class decorator that records *configs* for the decorated micro-gemm class
    in the module-level `micro_gemm_configs` registry.

    Each class may be registered exactly once, and at least one config must be
    supplied. Returns the class unchanged.
    """

    def _register(cls):
        # Guard against double registration and empty config lists up front.
        assert cls not in micro_gemm_configs, f"Duplicate micro_gemm registration for {cls}"
        assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
        micro_gemm_configs[cls] = list(configs)
        return cls

    return _register
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def generate_gemm_config(
    vec_isa_cls,
    register_blockings,
    input_dtype=torch.float,
    input2_dtype=None,
    output_dtype=None,
    compute_dtype=None,
    extra_check=None,
):
    """Build one `CppMicroGemmConfig` per register blocking in *register_blockings*.

    Unspecified dtypes fall back in a chain: output defaults to the input dtype,
    compute defaults to the output dtype, and the second input defaults to the
    first input's dtype.
    """
    output_dtype = input_dtype if output_dtype is None else output_dtype
    compute_dtype = output_dtype if compute_dtype is None else compute_dtype
    input2_dtype = input_dtype if input2_dtype is None else input2_dtype
    configs = []
    for blocking in register_blockings:
        configs.append(
            CppMicroGemmConfig(
                input_dtype,
                input2_dtype,
                output_dtype,
                compute_dtype,
                vec_isa_cls,
                GemmBlocking(*blocking),
                extra_check,
            )
        )
    return configs
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class CppMicroGemmRef(CppMicroGemm):
    """
    A reference implementation of the CppMicroGemm class with naive C++ code.
    It is used for correctness debugging.
    """

    # Naive triple loop; `accum` selects between accumulating into C and
    # overwriting it. Products are computed in compute_t and scaled by alpha.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            {{compute_t}} result = accum ? C[m * ldc + n] : 0;
            for (int64_t k = 0; k < K; ++k) {
                result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
            }
            C[m * ldc + n] = result;
        }
    }
}
"""

    def __init__(
        self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
    ) -> None:
        # The reference kernel processes one element at a time, hence 1x1x1 blocking.
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            GemmBlocking(1, 1, 1),
            alpha,
        )

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the naive C++ kernel definition with this instance's options."""
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            **self.get_common_options(),
        }
        return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@register_micro_gemm(
    # AVX512: fp32, bf16->fp32, fp16->fp32, and bf16 x int8 -> fp32 variants.
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
    # AVX2: same dtype variants with smaller register blockings.
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
)
class CppMicroGemmFP32Vec(CppMicroGemm):
    """
    This class generates the code for micro gemm using fp32 vec instructions for compute.
    It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
    The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
    if the desired output is BF16/FP16.
    """

    # Entry point: tiles M by block_m; full tiles call the templated kernel with
    # the static block size, the M remainder dispatches through a switch over
    # every smaller row count. N and K must divide the block sizes exactly.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            if (block_m == {{block_m}}) {
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc
                );
            } else {
                switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
                case {{b}}:
                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
                        A + m * lda,
                        B + n,
                        C + m * ldc + n,
                        K,
                        lda,
                        ldb,
                        ldc
                    );
                    break;
{%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
"""

    # Inner kernel: keeps a ROWS x COLS grid of accumulator vectors in
    # registers, broadcasting one A element per row and loading (and, for
    # bf16/fp16/int8 B, converting to compute_t) one B vector per column,
    # combined with fmadd over the K loop.
    TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
    constexpr auto VLEN = Vectorized::size();
    constexpr auto ROWS = BLOCK_M;
    constexpr auto COLS = BLOCK_N / VLEN;

    Vectorized va;
    at::vec::VectorizedN<{{compute_t}}, COLS> vb;
    at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;

    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
        } else {
            vc[i] = Vectorized(0.0f);
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);

    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;

        if constexpr (col == 0) {
{%- if alpha != 1 %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
{%- else %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
{%- endif %}
        }

        if constexpr (row == 0) {
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
            vb[col] = at::vec::convert<{{compute_t}}>(b);
{%- elif input2_dtype == torch.int8 %}
            // Convert VLEN int8 elements to int32, and then fp32
            auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
            vb[col] = at::vec::convert<float>(b32);
{%- else %}
            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
{%- endif %}
        }

        constexpr int idx = row * COLS + col;
        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
    };

    for (int k = 0; k < K; ++k) {
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }

    // store to C
    auto storec = [&](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        vc[i].store(C + row * ldc + col * VLEN);
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the inner kernel followed by the tiling entry point."""
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
            options
        )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# extra check for CppMicroGemmAMX
|
| 470 |
+
def check_amx_extra(config, m, n, k, alpha, num_threads):
    """Extra eligibility check used by the AMX micro-gemm configs.

    The AMX kernels require alpha == 1 and a K dimension divisible by the
    VNNI packing factor (4 for uint8 inputs, 2 otherwise).
    """
    if alpha != 1:
        return False
    vnni = 4 if config.input_dtype == torch.uint8 else 2
    return k % vnni == 0
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
@register_micro_gemm(
|
| 476 |
+
*generate_gemm_config(
|
| 477 |
+
VecAMX,
|
| 478 |
+
[(32, 32, 32), (48, 16, 32), (16, 48, 32)],
|
| 479 |
+
input_dtype=torch.bfloat16,
|
| 480 |
+
input2_dtype=torch.int8,
|
| 481 |
+
output_dtype=torch.float,
|
| 482 |
+
compute_dtype=torch.float,
|
| 483 |
+
extra_check=check_amx_extra,
|
| 484 |
+
),
|
| 485 |
+
*generate_gemm_config(
|
| 486 |
+
VecAMX,
|
| 487 |
+
[(32, 32, 32), (48, 16, 32), (16, 48, 32)],
|
| 488 |
+
input_dtype=torch.bfloat16,
|
| 489 |
+
output_dtype=torch.float,
|
| 490 |
+
extra_check=check_amx_extra,
|
| 491 |
+
),
|
| 492 |
+
*generate_gemm_config(
|
| 493 |
+
VecAMX,
|
| 494 |
+
[(32, 32, 64), (48, 16, 64)],
|
| 495 |
+
input_dtype=torch.uint8,
|
| 496 |
+
input2_dtype=torch.int8,
|
| 497 |
+
output_dtype=torch.int32,
|
| 498 |
+
compute_dtype=torch.int32,
|
| 499 |
+
extra_check=check_amx_extra,
|
| 500 |
+
),
|
| 501 |
+
)
|
| 502 |
+
class CppMicroGemmAMX(CppMicroGemm):
|
| 503 |
+
"""
|
| 504 |
+
This class generates the code for micro gemm using Advanced Matrix eXtention (AMX)
|
| 505 |
+
instructions available in 4th generation Intel Xeon for compute.
|
| 506 |
+
It supports input types of torch.bfloat16 with fp32 output.
|
| 507 |
+
TODO(jgong5): support int8 data type.
|
| 508 |
+
"""
|
| 509 |
+
|
| 510 |
+
TEMPLATE_ENTRY = r"""
|
| 511 |
+
{{declare_kernel}} {
|
| 512 |
+
TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
|
| 513 |
+
TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
|
| 514 |
+
// TODO(jgong5): loop unroll for M and N
|
| 515 |
+
for (int64_t m = 0; m < M; m += {{block_m}}) {
|
| 516 |
+
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
|
| 517 |
+
int64_t m_tail = m;
|
| 518 |
+
for (int64_t n = 0; n < N; n += {{block_n}}) {
|
| 519 |
+
{%- for num_rows in range(block_m, 0, -16) %}
|
| 520 |
+
{%- if num_rows != block_m %}
|
| 521 |
+
else
|
| 522 |
+
{%- endif %}
|
| 523 |
+
if (block_m >= {{num_rows}}) {
|
| 524 |
+
{{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
|
| 525 |
+
amx_state,
|
| 526 |
+
A + m * lda,
|
| 527 |
+
B + n,
|
| 528 |
+
C + m * ldc + n,
|
| 529 |
+
K,
|
| 530 |
+
lda,
|
| 531 |
+
ldb,
|
| 532 |
+
ldc,
|
| 533 |
+
16
|
| 534 |
+
);
|
| 535 |
+
block_m -= {{num_rows}};
|
| 536 |
+
m_tail += {{num_rows}};
|
| 537 |
+
}
|
| 538 |
+
{%- endfor %}
|
| 539 |
+
if (block_m > 0) {
|
| 540 |
+
{{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
|
| 541 |
+
amx_state,
|
| 542 |
+
A + m_tail * lda,
|
| 543 |
+
B + n,
|
| 544 |
+
C + m_tail * ldc + n,
|
| 545 |
+
K,
|
| 546 |
+
lda,
|
| 547 |
+
ldb,
|
| 548 |
+
ldc,
|
| 549 |
+
block_m
|
| 550 |
+
);
|
| 551 |
+
}
|
| 552 |
+
}
|
| 553 |
+
}
|
| 554 |
+
}
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
TEMPLATE_KERNEL = r"""
|
| 558 |
+
template <bool accum>
|
| 559 |
+
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
|
| 560 |
+
AMXState& amx_state,
|
| 561 |
+
const {{input_t}}* {{restrict_keyword}} A,
|
| 562 |
+
const {{input2_t}}* {{restrict_keyword}} B,
|
| 563 |
+
{{output_t}}* {{restrict_keyword}} C,
|
| 564 |
+
int64_t K,
|
| 565 |
+
int64_t lda,
|
| 566 |
+
int64_t ldb,
|
| 567 |
+
int64_t ldc,
|
| 568 |
+
uint8_t tilecfg_rows
|
| 569 |
+
) {
|
| 570 |
+
// TODO(jgong5): add prefetch hint for A, B, C
|
| 571 |
+
auto loadconfig = [](const amx_tilecfg& cfg) {
|
| 572 |
+
_tile_loadconfig(&cfg);
|
| 573 |
+
};
|
| 574 |
+
const auto last_k_offset = K / {{block_k}} * {{block_k}};
|
| 575 |
+
const auto tail_k_size = K - last_k_offset;
|
| 576 |
+
if C10_LIKELY (last_k_offset > 0) {
|
| 577 |
+
amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
| 578 |
+
} else {
|
| 579 |
+
amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
| 580 |
+
}
|
| 581 |
+
auto load_c = [&]() {
|
| 582 |
+
{%- for tile_row in range(num_rows // 16) %}
|
| 583 |
+
{%- for tile_col in range(num_columns) %}
|
| 584 |
+
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
| 585 |
+
_tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
|
| 586 |
+
{%- endfor %}
|
| 587 |
+
{%- endfor %}
|
| 588 |
+
};
|
| 589 |
+
auto zero_c = [&]() {
|
| 590 |
+
{%- for tile_row in range(num_rows // 16) %}
|
| 591 |
+
{%- for tile_col in range(num_columns) %}
|
| 592 |
+
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
| 593 |
+
_tile_zero({{tile_idx}});
|
| 594 |
+
{%- endfor %}
|
| 595 |
+
{%- endfor %}
|
| 596 |
+
};
|
| 597 |
+
|
| 598 |
+
if constexpr (accum) {
|
| 599 |
+
load_c();
|
| 600 |
+
} else {
|
| 601 |
+
zero_c();
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
{%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
|
| 605 |
+
// create a buffer for tiles of B.
|
| 606 |
+
alignas(64) {{input_t}} bf16_weights_buf[512];
|
| 607 |
+
|
| 608 |
+
int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4;
|
| 609 |
+
int b_tile_ptr_stride = ldb * {{vnni_size}};
|
| 610 |
+
|
| 611 |
+
auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) {
|
| 612 |
+
{{kernel.unroll_pragma(2)}}
|
| 613 |
+
for (int i = 0; i < 2; i++) {
|
| 614 |
+
// int8 -> int32 -> fp32 -> bf16
|
| 615 |
+
auto b32 = at::vec::convert_to_int32<int8_t>(src + i * 16);
|
| 616 |
+
auto b_bf16 = at::vec::convert<{{input_t}}>(b32);
|
| 617 |
+
b_bf16.store(dst + i * 16);
|
| 618 |
+
}
|
| 619 |
+
};
|
| 620 |
+
|
| 621 |
+
auto load_B_in_buf = [&]({{input2_t}}* B_ptr) {
|
| 622 |
+
{{kernel.unroll_pragma(8)}}
|
| 623 |
+
for (int i = 0; i < num_b_rows; i++) {
|
| 624 |
+
load_B_row(
|
| 625 |
+
B_ptr + i * b_tile_ptr_stride,
|
| 626 |
+
bf16_weights_buf + i * 32
|
| 627 |
+
);
|
| 628 |
+
}
|
| 629 |
+
};
|
| 630 |
+
{%- endif %}
|
| 631 |
+
|
| 632 |
+
auto compute = [&](int k) {
|
| 633 |
+
{%- set tile_offset_a = num_rows // 16 * num_columns %}
|
| 634 |
+
{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
|
| 635 |
+
{%- for tile_row in range(num_rows // 16) %}
|
| 636 |
+
{%- for tile_col in range(num_columns) %}
|
| 637 |
+
{%- set tile_idx_a = tile_offset_a + tile_row %}
|
| 638 |
+
{%- set tile_idx_b = tile_offset_b + tile_col %}
|
| 639 |
+
{%- set tile_idx_c = tile_row * num_columns + tile_col %}
|
| 640 |
+
{%- if tile_col == 0 %}
|
| 641 |
+
_tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
|
| 642 |
+
{%- endif %}
|
| 643 |
+
{%- if tile_row == 0 %}
|
| 644 |
+
{%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
|
| 645 |
+
load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}});
|
| 646 |
+
_tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64);
|
| 647 |
+
{%- else %}
|
| 648 |
+
_tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
|
| 649 |
+
{%- endif %}
|
| 650 |
+
{%- endif %}
|
| 651 |
+
{%- if int8_gemm %}
|
| 652 |
+
_tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
|
| 653 |
+
{%- else %}
|
| 654 |
+
_tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
|
| 655 |
+
{%- endif %}
|
| 656 |
+
{%- endfor %}
|
| 657 |
+
{%- endfor %}
|
| 658 |
+
};
|
| 659 |
+
|
| 660 |
+
{{kernel.unroll_pragma(4)}}
|
| 661 |
+
for (int k = 0; k < last_k_offset; k += {{block_k}}) {
|
| 662 |
+
compute(k);
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
auto store_c = [&]() {
|
| 666 |
+
// store to C
|
| 667 |
+
{%- for tile_row in range(num_rows // 16) %}
|
| 668 |
+
{%- for tile_col in range(num_columns) %}
|
| 669 |
+
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
| 670 |
+
_tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
|
| 671 |
+
{%- endfor %}
|
| 672 |
+
{%- endfor %}
|
| 673 |
+
};
|
| 674 |
+
|
| 675 |
+
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
|
| 676 |
+
if C10_UNLIKELY (tail_k_size > 0) {
|
| 677 |
+
if C10_LIKELY (last_k_offset > 0) {
|
| 678 |
+
store_c();
|
| 679 |
+
amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
| 680 |
+
load_c();
|
| 681 |
+
}
|
| 682 |
+
compute(last_k_offset);
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
store_c();
|
| 686 |
+
}
|
| 687 |
+
"""
|
| 688 |
+
|
| 689 |
+
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
| 690 |
+
block_m, block_n, block_k = self.register_blocking
|
| 691 |
+
assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
|
| 692 |
+
assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
|
| 693 |
+
if self.input_dtype == torch.uint8:
|
| 694 |
+
assert block_k == 64, "Only support block_k = 64 for AMX INT8"
|
| 695 |
+
else:
|
| 696 |
+
assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
|
| 697 |
+
num_columns = block_n // 16
|
| 698 |
+
options = {
|
| 699 |
+
"declare_kernel": self.get_kernel_declaration(),
|
| 700 |
+
"kernel": kernel,
|
| 701 |
+
"block_m": block_m,
|
| 702 |
+
"block_n": block_n,
|
| 703 |
+
"block_k": block_k,
|
| 704 |
+
"num_columns": num_columns,
|
| 705 |
+
"restrict_keyword": get_restrict_keyword(),
|
| 706 |
+
**self.get_common_options(),
|
| 707 |
+
}
|
| 708 |
+
result = ""
|
| 709 |
+
for num_rows in range(block_m, 0, -16):
|
| 710 |
+
amx_kernel_options = {**options, "num_rows": num_rows}
|
| 711 |
+
result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
|
| 712 |
+
amx_kernel_options
|
| 713 |
+
)
|
| 714 |
+
result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
|
| 715 |
+
options
|
| 716 |
+
)
|
| 717 |
+
return result
|
| 718 |
+
|
| 719 |
+
def codegen_init(
|
| 720 |
+
self,
|
| 721 |
+
kernel: CppTemplateKernel,
|
| 722 |
+
) -> str:
|
| 723 |
+
return "AMXState amx_state;"
|
| 724 |
+
|
| 725 |
+
def codegen_finalize(
|
| 726 |
+
self,
|
| 727 |
+
kernel: CppTemplateKernel,
|
| 728 |
+
) -> str:
|
| 729 |
+
return "amx_state.release([]() { _tile_release(); });"
|
| 730 |
+
|
| 731 |
+
def get_kernel_extra_args_declare(self) -> str:
|
| 732 |
+
return "AMXState& amx_state,"
|
| 733 |
+
|
| 734 |
+
def get_kernel_extra_args(self) -> str:
|
| 735 |
+
return "amx_state,"
|
| 736 |
+
|
| 737 |
+
def get_b_layout(self):
|
| 738 |
+
if self.input_dtype == torch.uint8:
|
| 739 |
+
return LayoutType.VNNI4
|
| 740 |
+
else:
|
| 741 |
+
return LayoutType.VNNI2
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def create_micro_gemm(
|
| 745 |
+
name,
|
| 746 |
+
m,
|
| 747 |
+
n,
|
| 748 |
+
k,
|
| 749 |
+
input_dtype,
|
| 750 |
+
input2_dtype,
|
| 751 |
+
output_dtype=None,
|
| 752 |
+
compute_dtype=None,
|
| 753 |
+
alpha=1,
|
| 754 |
+
num_threads=-1,
|
| 755 |
+
use_ref=True,
|
| 756 |
+
) -> Optional[CppMicroGemm]:
|
| 757 |
+
def create_from_config(cls, config: CppMicroGemmConfig):
|
| 758 |
+
return cls(
|
| 759 |
+
name,
|
| 760 |
+
config.input_dtype,
|
| 761 |
+
config.input2_dtype,
|
| 762 |
+
config.output_dtype,
|
| 763 |
+
config.compute_dtype,
|
| 764 |
+
config.register_blocking,
|
| 765 |
+
alpha,
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
assert isinstance(n, int) or n.is_number, n
|
| 769 |
+
assert isinstance(k, int) or k.is_number, k
|
| 770 |
+
m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
|
| 771 |
+
assert isinstance(m, int), m
|
| 772 |
+
if output_dtype is None:
|
| 773 |
+
output_dtype = input_dtype
|
| 774 |
+
if compute_dtype is None:
|
| 775 |
+
compute_dtype = output_dtype
|
| 776 |
+
if num_threads < 0:
|
| 777 |
+
num_threads = parallel_num_threads()
|
| 778 |
+
vec_isa = pick_vec_isa()
|
| 779 |
+
matched_configs = []
|
| 780 |
+
for cls, configs in micro_gemm_configs.items():
|
| 781 |
+
for config in configs:
|
| 782 |
+
if not issubclass(vec_isa.__class__, config.vec_isa_cls):
|
| 783 |
+
continue
|
| 784 |
+
if (
|
| 785 |
+
config.input_dtype == input_dtype
|
| 786 |
+
and config.compute_dtype == compute_dtype
|
| 787 |
+
and config.input2_dtype == input2_dtype
|
| 788 |
+
and config.output_dtype == output_dtype
|
| 789 |
+
# The output_dtype here is the output dtype of the micro-kernel.
|
| 790 |
+
# In some cases, the actual output dtype of the op for which the micro-kernel
|
| 791 |
+
# is being created would be same as that of the activation, but the micro-kernels
|
| 792 |
+
# compute output in Float/int32, which is converted in the GEMM template. This is
|
| 793 |
+
# subject to change in the future.
|
| 794 |
+
):
|
| 795 |
+
if config.extra_check is not None and not config.extra_check(
|
| 796 |
+
config, m, n, k, alpha, num_threads
|
| 797 |
+
):
|
| 798 |
+
continue
|
| 799 |
+
block_m, block_n, block_k = config.register_blocking
|
| 800 |
+
if (
|
| 801 |
+
config.vec_isa_cls == VecAMX
|
| 802 |
+
and m < block_m
|
| 803 |
+
and input_dtype == torch.bfloat16
|
| 804 |
+
and input2_dtype == torch.int8
|
| 805 |
+
):
|
| 806 |
+
# For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
|
| 807 |
+
continue
|
| 808 |
+
# Criteria on the ranking of configurations
|
| 809 |
+
# 1. ISA: AMX > VEC
|
| 810 |
+
# 2. Dividable by block sizes (block_m, block_n, block_k)
|
| 811 |
+
# 3. Number of mxn blocks is large enough to occupy all the threads
|
| 812 |
+
# 4. Register blocks are larger
|
| 813 |
+
isa_score = 0
|
| 814 |
+
if config.vec_isa_cls == VecAMX:
|
| 815 |
+
isa_score += 1
|
| 816 |
+
dividable_score = 0
|
| 817 |
+
if m % block_m == 0:
|
| 818 |
+
dividable_score += 1
|
| 819 |
+
if n % block_n == 0:
|
| 820 |
+
dividable_score += 1
|
| 821 |
+
if k % block_k == 0:
|
| 822 |
+
dividable_score += 1
|
| 823 |
+
occupancy_score = 0
|
| 824 |
+
n_blocks = (n + block_n - 1) // block_n
|
| 825 |
+
total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
|
| 826 |
+
if n_blocks >= num_threads:
|
| 827 |
+
occupancy_score += 1
|
| 828 |
+
if total_mxn_blocks >= num_threads:
|
| 829 |
+
occupancy_score += 1
|
| 830 |
+
register_bytes = (
|
| 831 |
+
block_m * block_n * config.compute_dtype.itemsize
|
| 832 |
+
+ (block_m * block_k + block_k * block_n)
|
| 833 |
+
* config.input_dtype.itemsize
|
| 834 |
+
)
|
| 835 |
+
matched_configs.append(
|
| 836 |
+
(
|
| 837 |
+
(isa_score, dividable_score, occupancy_score, register_bytes),
|
| 838 |
+
cls,
|
| 839 |
+
config,
|
| 840 |
+
)
|
| 841 |
+
)
|
| 842 |
+
if len(matched_configs) == 0:
|
| 843 |
+
if use_ref:
|
| 844 |
+
return CppMicroGemmRef(
|
| 845 |
+
name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
|
| 846 |
+
)
|
| 847 |
+
else:
|
| 848 |
+
return None
|
| 849 |
+
# TODO(jgong5): allow autotuning on choices of configs
|
| 850 |
+
return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:])
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import ctypes
|
| 3 |
+
import functools
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Callable, List, Optional
|
| 8 |
+
from unittest.mock import patch
|
| 9 |
+
|
| 10 |
+
import sympy
|
| 11 |
+
|
| 12 |
+
from .. import codecache, config, ir
|
| 13 |
+
from ..autotune_process import CppBenchmarkRequest, TensorMeta
|
| 14 |
+
from ..utils import IndentedBuffer, Placeholder, unique
|
| 15 |
+
from ..virtualized import V
|
| 16 |
+
from .common import KernelTemplate
|
| 17 |
+
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CppTemplate(KernelTemplate):
|
| 24 |
+
index_counter = itertools.count()
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
name: str,
|
| 29 |
+
input_nodes,
|
| 30 |
+
layout: ir.Layout,
|
| 31 |
+
num_threads: int,
|
| 32 |
+
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
| 33 |
+
) -> None:
|
| 34 |
+
super().__init__(name)
|
| 35 |
+
self.input_nodes = input_nodes
|
| 36 |
+
self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
|
| 37 |
+
self.layout = layout
|
| 38 |
+
self.num_threads = num_threads
|
| 39 |
+
self.epilogue_creator = epilogue_creator
|
| 40 |
+
|
| 41 |
+
def generate(self, **kwargs):
|
| 42 |
+
kernel_name = f"cpp_{self.name}"
|
| 43 |
+
with patch.object(
|
| 44 |
+
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
|
| 45 |
+
), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
|
| 46 |
+
kernel_name=kernel_name, num_threads=self.num_threads
|
| 47 |
+
) as kernel:
|
| 48 |
+
code = kernel.render(self, **kwargs)
|
| 49 |
+
_, call_args, _, _ = kernel.args.python_argdefs()
|
| 50 |
+
log.debug("Generated Code:\n%s", code)
|
| 51 |
+
log.debug(
|
| 52 |
+
"Args: cpp_argdefs: %s, python_argdefs: %s",
|
| 53 |
+
kernel.args.cpp_argdefs(),
|
| 54 |
+
kernel.args.python_argdefs(),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
expected_args = list(
|
| 58 |
+
unique(input_node.get_name() for input_node in self.input_nodes)
|
| 59 |
+
)
|
| 60 |
+
expected_args.extend([self.output_node.get_name()])
|
| 61 |
+
assert list(call_args)[: len(expected_args)] == expected_args, (
|
| 62 |
+
call_args,
|
| 63 |
+
expected_args,
|
| 64 |
+
)
|
| 65 |
+
extra_args = V.graph.sizevars.size_hints(
|
| 66 |
+
map(sympy.expand, call_args[len(expected_args) :])
|
| 67 |
+
)
|
| 68 |
+
# Cast the size hint from int to ctypes.c_ulonglong explicitly
|
| 69 |
+
# since in cpp kernel, we bind it to C long
|
| 70 |
+
extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args)
|
| 71 |
+
|
| 72 |
+
kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"
|
| 73 |
+
|
| 74 |
+
# Create the BenchmarkRequest for CPP
|
| 75 |
+
bmreq = CppBenchmarkRequest(
|
| 76 |
+
kernel_name=kernel_name,
|
| 77 |
+
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
|
| 78 |
+
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
| 79 |
+
extra_args=extra_args,
|
| 80 |
+
source_code=code,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def make_kernel_render(
|
| 84 |
+
template_node: ir.CppTemplateBuffer,
|
| 85 |
+
flag_template_buffer_has_other_users: bool,
|
| 86 |
+
epilogue_nodes: Optional[List[ir.IRNode]] = None,
|
| 87 |
+
):
|
| 88 |
+
kernel = CppTemplateKernel(
|
| 89 |
+
kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads
|
| 90 |
+
)
|
| 91 |
+
render = functools.partial(
|
| 92 |
+
kernel.render,
|
| 93 |
+
self,
|
| 94 |
+
template_buffer_node=template_node,
|
| 95 |
+
flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
|
| 96 |
+
epilogue_nodes=epilogue_nodes,
|
| 97 |
+
**kwargs,
|
| 98 |
+
)
|
| 99 |
+
return kernel, render
|
| 100 |
+
|
| 101 |
+
return CppTemplateCaller(
|
| 102 |
+
kernel_hash_name,
|
| 103 |
+
self.name,
|
| 104 |
+
self.input_nodes,
|
| 105 |
+
self.output_node.get_layout(),
|
| 106 |
+
make_kernel_render,
|
| 107 |
+
bmreq,
|
| 108 |
+
self,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def header(self) -> IndentedBuffer:
|
| 112 |
+
res = IndentedBuffer()
|
| 113 |
+
res.writeline(codecache.cpp_prefix())
|
| 114 |
+
res.splice(
|
| 115 |
+
"""
|
| 116 |
+
#include "c10/util/Unroll.h"
|
| 117 |
+
"""
|
| 118 |
+
)
|
| 119 |
+
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
|
| 120 |
+
"linux",
|
| 121 |
+
"win32",
|
| 122 |
+
]
|
| 123 |
+
if enable_kernel_profile:
|
| 124 |
+
res.writelines(["#include <ATen/record_function.h>"])
|
| 125 |
+
return res
|
| 126 |
+
|
| 127 |
+
def render(self, **kwargs) -> str:
|
| 128 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import itertools
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils._sympy.symbol import SymT
|
| 10 |
+
|
| 11 |
+
from .. import config, cpp_builder, ir, lowering as L
|
| 12 |
+
from ..autotune_process import CppBenchmarkRequest
|
| 13 |
+
from ..loop_body import LoopBody
|
| 14 |
+
from ..select_algorithm import PartialRender
|
| 15 |
+
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
|
| 16 |
+
from ..virtualized import V
|
| 17 |
+
from .common import CppWrapperKernelArgs
|
| 18 |
+
from .cpp import CppKernel, CppKernelProxy, KernelGroup
|
| 19 |
+
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
|
| 20 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_expr_with_index_symbols(expr):
|
| 24 |
+
if isinstance(expr, sympy.Expr):
|
| 25 |
+
return expr
|
| 26 |
+
elif isinstance(expr, (list, tuple)):
|
| 27 |
+
return [parse_expr_with_index_symbols(e) for e in expr]
|
| 28 |
+
else:
|
| 29 |
+
expr = parse_expr(str(expr))
|
| 30 |
+
int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols}
|
| 31 |
+
return expr.subs(int_symbols)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def wrap_with_tensorbox(node) -> ir.TensorBox:
|
| 35 |
+
return (
|
| 36 |
+
ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node)
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CppTemplateKernel(CppKernel):
|
| 41 |
+
def __init__(self, kernel_name, num_threads):
|
| 42 |
+
super().__init__(None, num_threads)
|
| 43 |
+
self.kernel_name = kernel_name
|
| 44 |
+
self.render_hooks = {}
|
| 45 |
+
self.local_buffers = {}
|
| 46 |
+
if isinstance(V.graph.wrapper_code, CppWrapperCpu):
|
| 47 |
+
self.args = CppWrapperKernelArgs()
|
| 48 |
+
|
| 49 |
+
def render(self, template, **kwargs):
|
| 50 |
+
return PartialRender(
|
| 51 |
+
template.render(kernel=self, **kwargs), self.render_hooks
|
| 52 |
+
).finalize_all()
|
| 53 |
+
|
| 54 |
+
def def_kernel(
|
| 55 |
+
self,
|
| 56 |
+
inputs: Dict[str, ir.Buffer],
|
| 57 |
+
outputs: Dict[str, ir.Buffer],
|
| 58 |
+
aliases: Optional[Dict[str, str]] = None,
|
| 59 |
+
) -> str:
|
| 60 |
+
for name, inp in inputs.items():
|
| 61 |
+
if inp is not None:
|
| 62 |
+
self.args.input_buffers[inp.get_name()] = name
|
| 63 |
+
for name, out in outputs.items():
|
| 64 |
+
self.args.output_buffers[out.get_name()] = name
|
| 65 |
+
if aliases is not None:
|
| 66 |
+
for alias, orig in aliases.items():
|
| 67 |
+
if orig in self.args.input_buffers:
|
| 68 |
+
self.args.input_buffers[alias] = self.args.input_buffers[orig]
|
| 69 |
+
if orig in self.args.output_buffers:
|
| 70 |
+
self.args.output_buffers[alias] = self.args.output_buffers[orig]
|
| 71 |
+
|
| 72 |
+
unique_sizevars = {
|
| 73 |
+
s
|
| 74 |
+
for input in inputs.values()
|
| 75 |
+
if input is not None
|
| 76 |
+
for sym in itertools.chain(input.get_size(), input.get_stride())
|
| 77 |
+
if isinstance(sym, sympy.Expr)
|
| 78 |
+
for s in sym.free_symbols
|
| 79 |
+
}
|
| 80 |
+
unique_sizevars |= {
|
| 81 |
+
s
|
| 82 |
+
for output in outputs.values()
|
| 83 |
+
for sym in itertools.chain(output.get_size(), output.get_stride())
|
| 84 |
+
if isinstance(sym, sympy.Expr)
|
| 85 |
+
for s in sym.free_symbols
|
| 86 |
+
}
|
| 87 |
+
sizevars = sorted(unique_sizevars, key=str)
|
| 88 |
+
for sizevar in sizevars:
|
| 89 |
+
self.args.sizevars[sizevar] = f"k{sizevar}"
|
| 90 |
+
|
| 91 |
+
def hook():
|
| 92 |
+
# remove all aliases before generate function definition
|
| 93 |
+
if aliases is not None:
|
| 94 |
+
for alias in aliases:
|
| 95 |
+
if alias in self.args.input_buffers:
|
| 96 |
+
self.args.input_buffers[alias] = "REMOVED"
|
| 97 |
+
if alias in self.args.output_buffers:
|
| 98 |
+
self.args.output_buffers[alias] = "REMOVED"
|
| 99 |
+
cpp_argdefs, _, _ = self.args.cpp_argdefs()
|
| 100 |
+
return f"void {self.kernel_name}({', '.join(cpp_argdefs)})"
|
| 101 |
+
|
| 102 |
+
placeholder = "<DEF_KERNEL>"
|
| 103 |
+
assert placeholder not in self.render_hooks
|
| 104 |
+
self.render_hooks[placeholder] = hook
|
| 105 |
+
return placeholder
|
| 106 |
+
|
| 107 |
+
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
|
| 108 |
+
wrapper = V.graph.wrapper_code
|
| 109 |
+
_, call_args, arg_types = self.args.cpp_argdefs()
|
| 110 |
+
wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types)
|
| 111 |
+
|
| 112 |
+
def dtype(self, node: ir.Buffer) -> str:
|
| 113 |
+
return DTYPE_TO_CPP[node.get_dtype()]
|
| 114 |
+
|
| 115 |
+
def acc_dtype(self, node: ir.Buffer) -> str:
|
| 116 |
+
if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]:
|
| 117 |
+
return "float"
|
| 118 |
+
else:
|
| 119 |
+
raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}")
|
| 120 |
+
|
| 121 |
+
def size(self, node: ir.Buffer, dim: int) -> str:
|
| 122 |
+
return cexpr_index(self.rename_indexing(node.get_size()[dim]))
|
| 123 |
+
|
| 124 |
+
def stride(self, node: ir.Buffer, dim: int) -> str:
|
| 125 |
+
return cexpr_index(self.rename_indexing(node.get_stride()[dim]))
|
| 126 |
+
|
| 127 |
+
def index(self, node: ir.Buffer, indices: List[Any]) -> str:
|
| 128 |
+
indexer = node.layout.as_fixed().make_indexer()
|
| 129 |
+
index = indexer(parse_expr_with_index_symbols(indices))
|
| 130 |
+
index = self.rename_indexing(index)
|
| 131 |
+
outer_name = node.get_name()
|
| 132 |
+
inner_name = (
|
| 133 |
+
outer_name
|
| 134 |
+
if outer_name in self.local_buffers
|
| 135 |
+
else self.args.input(node.get_name())
|
| 136 |
+
)
|
| 137 |
+
return f"{inner_name}[{cexpr_index(index)}]"
|
| 138 |
+
|
| 139 |
+
def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView:
|
| 140 |
+
"""
|
| 141 |
+
Slice the given node with a list of ranges (start and end) corresponding to its dims.
|
| 142 |
+
The dim is not sliced if the corresponding range is empty.
|
| 143 |
+
"""
|
| 144 |
+
assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
|
| 145 |
+
sliced = wrap_with_tensorbox(node)
|
| 146 |
+
for dim, _range in enumerate(ranges):
|
| 147 |
+
if len(_range) == 0:
|
| 148 |
+
continue
|
| 149 |
+
assert len(_range) == 2
|
| 150 |
+
start, end = parse_expr_with_index_symbols(_range)
|
| 151 |
+
sliced = L.slice_(sliced, dim, start, end, clamp=False)
|
| 152 |
+
assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
|
| 153 |
+
return sliced.data
|
| 154 |
+
|
| 155 |
+
def view(self, node, sizes: List[Any]) -> ir.View:
|
| 156 |
+
node = wrap_with_tensorbox(node)
|
| 157 |
+
sizes = parse_expr_with_index_symbols(sizes)
|
| 158 |
+
return L.view(node, sizes).data
|
| 159 |
+
|
| 160 |
+
def permute(self, node, dims):
|
| 161 |
+
node = wrap_with_tensorbox(node)
|
| 162 |
+
permuted = L.permute(node, dims).data
|
| 163 |
+
assert isinstance(permuted, ir.ReinterpretView)
|
| 164 |
+
return permuted
|
| 165 |
+
|
| 166 |
+
def maybe_codegen_profile(self) -> str:
|
| 167 |
+
if config.cpp.enable_kernel_profile:
|
| 168 |
+
graph_id = V.graph.graph_id
|
| 169 |
+
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
|
| 170 |
+
return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
|
| 171 |
+
else:
|
| 172 |
+
return ""
|
| 173 |
+
|
| 174 |
+
def unroll_pragma(self, unroll):
|
| 175 |
+
if cpp_builder.is_gcc():
|
| 176 |
+
return f"#pragma GCC unroll {unroll}"
|
| 177 |
+
else:
|
| 178 |
+
return f"#pragma unroll {unroll}"
|
| 179 |
+
|
| 180 |
+
def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str:
|
| 181 |
+
"""Define kernel local buffer"""
|
| 182 |
+
sizes = parse_expr_with_index_symbols(sizes)
|
| 183 |
+
buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, sizes))
|
| 184 |
+
self.local_buffers[name] = buf
|
| 185 |
+
ctype = f"{DTYPE_TO_CPP[dtype]}"
|
| 186 |
+
numel = f"{cexpr_index(buf.get_numel())}"
|
| 187 |
+
return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"
|
| 188 |
+
|
| 189 |
+
def reinit_buffer_if_null(self, name):
|
| 190 |
+
"""Reinit the previously defined local buffer if it is null"""
|
| 191 |
+
assert name in self.local_buffers
|
| 192 |
+
buf = self.local_buffers[name]
|
| 193 |
+
ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}"
|
| 194 |
+
numel = f"{cexpr_index(buf.get_numel())}"
|
| 195 |
+
return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}"
|
| 196 |
+
|
| 197 |
+
def release_buffer(self, name):
|
| 198 |
+
"""Codegen the code to release the ownership of a local buffer to others"""
|
| 199 |
+
assert name in self.local_buffers
|
| 200 |
+
return f"_{name}.release()"
|
| 201 |
+
|
| 202 |
+
    def store_pointwise_nodes(
        self,
        dst: ir.Buffer,
        nodes: List[ir.IRNode],
        offsets: Optional[List[sympy.Expr]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ) -> str:
        """Codegen C++ loops that evaluate the pointwise `nodes` and store the
        last node's result into `dst`.

        Each node is traced into a LoopBody over itervars sized by `dst`;
        intermediate nodes store into their own buffer names, the final node
        stores into `dst`. `offsets` (one per dst dim, default 0) are added to
        the loop indices before feeding them to a node; `reindexers` (one per
        node, default None) may further remap those indices. Returns the
        generated loop code as a string.
        """
        var_sizes = (tuple(dst.get_size()), ())
        var_ranges = {
            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
            for i, sz in enumerate(var_sizes[0])
        }
        if not offsets:
            offsets = [sympy.Integer(0)] * len(var_sizes[0])
        if not reindexers:
            reindexers = [None] * len(nodes)
        assert len(offsets) == len(var_sizes[0])
        output_index = dst.get_layout().make_indexer()(var_ranges.keys())
        kernel_group = KernelGroup()
        # Reuse this kernel's args so buffer names resolve consistently.
        kernel_group.args = self.args
        cpp_kernel_proxy = CppKernelProxy(kernel_group)
        bodies = []
        var_sizes_list = []
        for i, node in enumerate(nodes):
            # Only the last node writes to `dst`; earlier ones keep their own name.
            output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name()
            node = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(node, ir.Pointwise), node

            # NOTE(review): `fn` closes over the loop variables (i, node,
            # output_name); this is safe because LoopBody traces `fn`
            # immediately at construction below — TODO confirm if LoopBody
            # semantics ever change.
            def fn(*args):
                assert len(args) == 2
                assert len(args[0]) == len(var_sizes[0])
                assert len(args[1]) == 0
                new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
                if reindexers[i] is not None:
                    new_args = reindexers[i](new_args)  # type: ignore[misc]
                V.ops.store(
                    output_name,
                    output_index,
                    node.make_loader()(new_args).value,
                )

            body = LoopBody(
                fn,
                (list(var_ranges.keys()), ()),
                var_ranges,
                list(var_ranges.keys()),
                tuple(),
            )
            bodies.append(body)
            var_sizes_list.append(var_sizes)

        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
        kernel_group.finalize_kernel(cpp_kernel_proxy, [])
        return kernel_group.loops_code.getvalue()
|
| 256 |
+
|
| 257 |
+
    def store_output(
        self,
        dst: ir.Buffer,
        src: ir.Buffer,
        orig_src: Optional[ir.Buffer] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        offsets: Optional[List[Any]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ):
        """
        Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match.
        If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues
        before stored to `dst`. The `epilogues_nodes` are all pointwise.

        Notes:
        1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute
           and stores. In case `epilogue_nodes` are not provided, we do nothing.
        2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since
           they come form the original Inductor IR, they might need to be adjusted before working with
           `src` and `dst` as outlined below:
           a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on.
              In this case, the `offsets` could be provided to adjust the indices passed to
              `epilogue_nodes` during codegen and the data ranges are also configured according to
              the sizes of `src` and `dst`.
           b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
              needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
           c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
              in `epilogue_nodes` with `src`.
        """
        assert dst.get_size() == src.get_size(), f"{dst=}, {src=}"
        if offsets:
            offsets = parse_expr_with_index_symbols(offsets)
        if epilogue_nodes:
            with LocalBufferContext(self.args) as scope:
                assert orig_src is not None
                if orig_src.get_name() != src.get_name():
                    # Redirect references to the original (global) buffer onto
                    # the local `src` buffer inside the epilogue nodes.
                    scope.add_local_buffer(
                        src,
                        [
                            orig_src,
                        ],
                    )
                    epilogue_nodes = scope.localize_nodes(epilogue_nodes)
                return self.store_pointwise_nodes(
                    dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type]
                )
        else:
            if dst.get_name() != src.get_name():
                # src is local
                copy = L.copy(dst, src).data.data
                with LocalBufferContext(self.args) as scope:
                    scope.add_local_buffer(src)
                    return self.store_pointwise_nodes(dst, [copy])
            else:
                # In-place with no epilogues: nothing to emit, but the layouts
                # must agree for the no-op to be valid.
                assert dst.layout == src.layout, f"{dst=}, {src=}"
                return ""
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class CppTemplateCaller(ir.ChoiceCaller):
    """
    CppTemplateCaller

    This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CppBenchmarkRequest): The benchmark request for the caller.
        template_buffer (ir.CppTemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: List[ir.Buffer],
        layout: ir.Layout,
        make_kernel_render: Callable[
            [
                ir.CppTemplateBuffer,
                bool,
                Optional[List[ir.IRNode]],
            ],
            str,
        ],
        bmreq: CppBenchmarkRequest,
        template: "CppTemplate",  # type: ignore[name-defined]  # noqa: F821
        info_kwargs: Optional[
            Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]
        ] = None,
    ):
        super().__init__(name, input_nodes, layout)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        """Compile the benchmark request ahead of benchmarking."""
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        """Run the benchmark request, writing the result into `out`; returns timing."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def hash_key(self) -> str:
        """Stable key combining the category with the benchmark request's hash."""
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def info_dict(
        self,
    ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]:
        """Minimal backend metadata used for autotune logging."""
        return {"backend": "CPP", "op_type": "unknown"}

    def output_node(self) -> ir.TensorBox:
        """Materialize this choice as a CppTemplateBuffer wrapped in a TensorBox."""
        return ir.TensorBox.create(
            ir.CppTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                template=self.template,
                choice=self,
            )
        )
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import copy
|
| 4 |
+
import functools
|
| 5 |
+
import math
|
| 6 |
+
import sys
|
| 7 |
+
from collections import namedtuple
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
| 9 |
+
from unittest.mock import patch
|
| 10 |
+
|
| 11 |
+
import sympy
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch._prims_common import is_integer_dtype
|
| 15 |
+
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
| 16 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 17 |
+
|
| 18 |
+
from .. import ir
|
| 19 |
+
from ..loop_body import LoopBody
|
| 20 |
+
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
|
| 21 |
+
from ..virtualized import ops, OpsValue, V
|
| 22 |
+
from .common import (
|
| 23 |
+
CSEVariable,
|
| 24 |
+
deduce_output_dtype_by_name,
|
| 25 |
+
ExprPrinter,
|
| 26 |
+
Kernel,
|
| 27 |
+
KernelArgs,
|
| 28 |
+
OptimizationContext,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Map torch dtypes to the C++ type names used in generated kernel code.
DTYPE_TO_CPP = {
    torch.float32: "float",
    torch.float64: "double",
    torch.float16: "half",
    torch.int64: "int64_t",
    torch.int32: "int32_t",
    torch.int16: "int16_t",
    torch.int8: "int8_t",
    torch.uint64: "uint64_t",
    torch.uint32: "uint32_t",
    torch.uint16: "uint16_t",
    torch.uint8: "uint8_t",
    torch.bool: "bool",
    torch.bfloat16: "bfloat16",
    torch.complex64: "c10::complex<float>",
    torch.float8_e4m3fn: "float8_e4m3fn",
    torch.float8_e5m2: "float8_e5m2",
}
|
| 50 |
+
|
| 51 |
+
DTYPE_TO_ATEN = {
|
| 52 |
+
torch.float32: "at::kFloat",
|
| 53 |
+
torch.float64: "at::kDouble",
|
| 54 |
+
torch.float16: "at::kHalf",
|
| 55 |
+
torch.int64: "at::kLong",
|
| 56 |
+
torch.int32: "at::kInt",
|
| 57 |
+
torch.int16: "at::kShort",
|
| 58 |
+
torch.int8: "at::kChar",
|
| 59 |
+
torch.uint64: "at::kUInt64",
|
| 60 |
+
torch.uint32: "at::kUInt32",
|
| 61 |
+
torch.uint16: "at::kUInt16",
|
| 62 |
+
torch.uint8: "at::kByte",
|
| 63 |
+
torch.uint32: "at::kUInt32",
|
| 64 |
+
torch.uint64: "at::kUInt64",
|
| 65 |
+
torch.bool: "at::kBool",
|
| 66 |
+
torch.bfloat16: "at::kBFloat16",
|
| 67 |
+
torch.complex32: "at::kComplexHalf",
|
| 68 |
+
torch.complex64: "at::kComplexFloat",
|
| 69 |
+
torch.complex128: "at::kComplexDouble",
|
| 70 |
+
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
|
| 71 |
+
torch.float8_e5m2: "at::kFloat8_e5m2",
|
| 72 |
+
torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
|
| 73 |
+
torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
# Map inductor device strings to ATen device-type macros.
DEVICE_TO_ATEN = {
    "cpu": "at::kCPU",
    "cuda": "at::kCUDA",
}

# Map torch layouts to ATen layout macros.
LAYOUT_TO_ATEN = {
    torch.strided: "at::kStrided",
    torch._mkldnn: "at::kMkldnn",  # type: ignore[attr-defined]
}

_IS_WINDOWS = sys.platform == "win32"

# C++ type used for all generated index arithmetic.
INDEX_TYPE = "int64_t"

# Blocking factors (M/N/K tile sizes) for GEMM code generation.
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_promote_dtype(args):
    """Return the torch dtype obtained by promoting all CppCSEVariable args,
    or None when any of them is missing dtype information."""
    cse_vars = [n for n in args if isinstance(n, CppCSEVariable)]
    if any(v.dtype is None for v in cse_vars):
        # not enough info to calculate the promote dtype
        return None
    return functools.reduce(
        torch.promote_types,  # type: ignore[arg-type]
        [v.dtype for v in cse_vars],
    )
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def promote_args(new_args):
    """Cast every CppCSEVariable in ``new_args`` to their common promoted
    dtype; non-CSE args and already-matching args pass through unchanged.
    No casting happens when any CSE arg has an unknown dtype."""
    target = get_promote_dtype(new_args)

    def _cast(arg):
        if (
            isinstance(arg, CppCSEVariable)
            and arg.dtype
            and target
            and arg.dtype != target
        ):
            casted = ops.to_dtype(arg, target)
            if isinstance(casted, OpsValue):
                casted = casted.value
            casted.dtype = target
            return casted
        return arg

    dtypes_known = all(
        a.dtype is not None for a in new_args if isinstance(a, CppCSEVariable)
    )
    if dtypes_known and target:
        new_args = [_cast(a) for a in new_args]
    return new_args
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext:
    """Fetch the OptimizationContext stored in ``node.meta`` (None if absent)."""
    return node.meta.get(OptimizationContext.key)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_current_node_opt_ctx() -> OptimizationContext:
    """Return the OptimizationContext of the FX node currently being interpreted."""
    node = V.interpreter.current_node
    assert node
    return get_opt_ctx(node)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs):
    """Deduce the torch dtype of the CSE variable produced by op ``name``.

    Resolution order: (1) a name-based rule via deduce_output_dtype_by_name;
    (2) for "masked" ops, the dtype recorded in the current node's
    OptimizationContext; (3) otherwise, type promotion over the dtypes of the
    CppCSEVariable inputs (all of which must be known).
    """
    if (
        output_dtype := deduce_output_dtype_by_name(
            name,
            *args,
            **kwargs,
        )
    ) is not None:
        return output_dtype
    elif name == "masked":
        # <TODO> Leslie: perhaps we can also deduce the masked dtype by
        # inputs' CppCseVariable like other. Let's check it if any
        # unexpected failures.
        assert (
            hasattr(V.interpreter, "current_node")
            and V.interpreter.current_node.target.startswith("masked_subblock")
            and get_current_node_opt_ctx() is not None
        )
        return get_current_node_opt_ctx().dtype
    else:
        # deduce output dtype by inputs' dtype
        assert all(
            arg.dtype is not None for arg in args if isinstance(arg, CppCSEVariable)
        )
        return functools.reduce(
            torch.promote_types,  # type: ignore[arg-type]
            [arg.dtype for arg in args if isinstance(arg, CppCSEVariable)],
        )
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class CppCSEVariable(CSEVariable):
    """CSE variable for the C++ backend.

    On top of the base CSE variable it tracks:
      - ``is_vec``: whether the value is vectorized (propagated from args),
      - ``dtype``: the torch dtype deduced when the variable is produced,
      - ``dependent_itervars``: kernel itervars this value (transitively)
        depends on.
    """

    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
        super().__init__(name, bounds)
        self.is_vec = False
        # Deduced in update_on_args; stays None until then.
        self.dtype: Optional[torch.dtype] = None
        self.dependent_itervars: Set[sympy.Symbol] = set()

    def __repr__(self) -> str:
        return (
            f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
            f"dependent_itervars: {self.dependent_itervars})"
        )

    def update_on_args(self, name, args, kwargs):
        """Refresh itervar deps, is_vec, and dtype after op ``name`` produced
        this variable with the given args/kwargs."""
        if name == "load":
            # args[2] is index
            self._set_dependent_itervars(args[2])
        else:
            # propagate relevant itervars and is_vec from args
            self.dependent_itervars.update(
                *[
                    arg.dependent_itervars
                    for arg in args
                    if isinstance(arg, CppCSEVariable)
                ]
            )
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
                self.is_vec = True
        # NOTE [Deduce dtype of CppCSEVariable at runtime]
        self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
        assert self.dtype is not None

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Set the relevant itervars for this variable based on the `index` expression.
        This includes the itervars directly used in the `index` as well as relevant itervars
        of other cse variables used in the `index`.
        """
        for s in index.free_symbols:
            if s in V.kernel.itervars:
                self.dependent_itervars.add(s)  # type: ignore[arg-type]
            elif s.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
                # Inherit deps from the CSE variable the symbol refers to.
                self.dependent_itervars.update(
                    V.kernel.cse.varname_map[s.name].dependent_itervars  # type: ignore[attr-defined]
                )

    def depends_on(self, itervar: sympy.Symbol):
        """True if this value depends (directly or transitively) on ``itervar``."""
        return itervar in self.dependent_itervars
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class CppPrinter(ExprPrinter):
    """Print sympy expressions as C++ source strings.

    Integer-valued results are cast to INDEX_TYPE so generated index
    arithmetic is consistently 64-bit; floating results are left as-is.
    """

    def _print_Integer(self, expr):
        # Windows/macOS need the LL suffix for 64-bit literals.
        return (
            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
        )

    def _print_Where(self, expr):
        c = self.paren(self.doprint(expr.args[0]))
        p = self.paren(self.doprint(expr.args[1]))
        q = self.paren(self.doprint(expr.args[2]))
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr):
        # ModularIndexing(x, div, mod) == (x // div) % mod
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        if div != 1:
            div = self.paren(self.doprint(div))
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.paren(self.doprint(mod))
        return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr):
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    # TODO: This is wrong if one of the inputs is negative.  This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right).
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_CMod(self, expr):
        # C's % already has truncated (C) semantics, so print directly.
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr):
        # Uses float constants to perform FP div
        base, exp = expr.args
        base = self._print(base)

        if exp == 0.5 or exp == -0.5:
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                # Expand small integer powers into repeated multiplication.
                r = "*".join([self.paren(base)] * exp)
            elif exp < 0:
                r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
            else:  # exp == 0
                r = "1.0"

            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # TODO: float vs double
            return f"std::pow({base}, {float(exp)})"

    def _print_Rational(self, expr):
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr):
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min({il})"

    def _print_Max(self, expr):
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max({il})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr):
        return "true"

    def _print_BooleanFalse(self, expr):
        return "false"
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# A function to print, useful for printing sympy symbols.
cexpr = CppPrinter().doprint
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def cexpr_index(index):
    """Print a sympy index expression cast to the C++ index type."""
    printed = cexpr(index)
    return f"static_cast<{INDEX_TYPE}>({printed})"
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def value_to_cpp(value, cpp_type):
    """Render a Python scalar as a C++ expression of type ``cpp_type``.

    Infinities and NaN map to std::numeric_limits expressions; bools map to
    lowercase true/false; everything else is a static_cast of repr(value).
    """
    if value == float("-inf"):
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    if value == float("inf"):
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    if isinstance(value, bool):
        # Python's True/False must become C++'s true/false.
        return f"static_cast<{cpp_type}>({str(value).lower()})"
    if math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    return f"static_cast<{cpp_type}>({repr(value)})"
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def rewrite_index_for_function(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """Rewrite ``index`` (an access into the global buffer) so it addresses the
    corresponding local buffer, which covers only the innermost dimensions.

    Itervars that do not index into the local buffer's dimensions are zeroed
    out, since the local buffer is not laid out along them.
    """
    # Local buffer at the inner dimensions
    snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    scheduler_nodes = snode.get_nodes()
    # Prefer a reduction node's group so reduction ranges are included.
    _, (group, reduction_group) = max(
        scheduler_nodes, key=lambda x: int(x.is_reduction())
    ).group
    call_ranges = tuple(group) + tuple(reduction_group)
    # The innermost len(local size) itervars are the ones the local buffer uses.
    indices_to_keep = [
        f"x{len(call_ranges) - (idx + 1)}"
        for idx in range(len(local_buf.get_layout().size))
    ]
    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
    replacements = {}
    for x in sorted_symbols:
        if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined]
            # Only keep index used by local buffer
            replacements[x] = sympy.core.numbers.Zero()
    index = sympy_subs(index, replacements)  # type: ignore[arg-type]
    return index
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def rewrite_index_for_nodes(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """Re-express ``index`` through the local buffer's own indexer, keeping
    only the INDEX-typed itervars that actually appear in ``index`` and
    zeroing the rest."""
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
    dim_syms = [
        sympy_index_symbol_with_prefix(SymT.INDEX, dim)
        for dim in range(len(local_buf.get_size()))
    ]
    index_vars = [sym if sym in used_vars else 0 for sym in dim_syms]
    return local_buf.layout.make_indexer()(index_vars)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops handler wrapper that redirects loads/stores targeting a global buffer
    to its registered local replacement, rewriting the index on the way.
    """

    def __init__(
        self,
        inner,
        global_to_local: Dict[str, ir.Buffer],
        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
    ) -> None:
        super().__init__(inner)
        self.global_to_local = global_to_local
        self.rewrite_index = rewrite_index

    def localize(self, name: str, index: sympy.Expr):
        """Map a (global buffer name, index) pair to its local equivalent, if any."""
        if self.global_to_local and name in self.global_to_local:
            assert self.rewrite_index is not None
            index = self.rewrite_index(self, index, name)
            name = self.global_to_local[name].get_name()
        return name, index

    def load(self, name: str, index: sympy.Expr):
        return self._inner.load(*self.localize(name, index))

    def store(self, name, index, value, mode=None):
        localized_name, localized_index = self.localize(name, index)
        result = self._inner.store(localized_name, localized_index, value, mode)
        is_localized = bool(self.global_to_local) and name in self.global_to_local
        if is_localized and isinstance(V.kernel, Kernel):
            # Kernel.CSEProxy.store records the local buffer name in
            # Kernel.store_buffer_names; remove it since this buffer is not
            # a real (global) output of the kernel.
            V.kernel.store_buffer_names.discard(localized_name)
        return result

    def store_reduction(self, name, index, value):
        return self._inner.store_reduction(*self.localize(name, index), value)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class LocalBufferContext:
    """
    This class creates a context that helps to generate code involving Inductor
    IR with function-local buffers. These buffers are constructed during the
    codegen process and are used to store intermediate results such as local
    accumulators. We do not want to add them to `V.graph` since they are not
    global, and we do not want to add them as function arguments either. So we
    patch the codegen processes under this scope to support these buffers
    without exposure to the outside world.
    """

    def __init__(self, kernel_args: KernelArgs) -> None:
        self.kernel_args = kernel_args
        self.exit_stack = contextlib.ExitStack()
        # local buffer name -> local buffer
        self.local_buffers: Dict[str, ir.Buffer] = {}
        # global buffer name -> global buffer
        self.global_buffers: Dict[str, ir.Buffer] = {}
        # global buffer name -> the local buffer that replaces it
        self.global_to_local: Dict[str, ir.Buffer] = {}

    def __enter__(self):
        self.exit_stack.__enter__()

        # Patch dtype lookup so local buffers resolve without touching V.graph.
        original_get_dtype = V.graph.get_dtype

        def patched_get_dtype(name):
            if name in self.local_buffers:
                return self.local_buffers[name].get_dtype()
            return original_get_dtype(name)

        self.exit_stack.enter_context(
            patch.object(V.graph, "get_dtype", patched_get_dtype)
        )

        # Patch kernel-arg registration: local buffers are never real
        # kernel inputs/outputs, so just echo the name back.
        original_input = self.kernel_args.input

        def patched_input(name):
            if name in self.local_buffers:
                return name
            return original_input(name)

        self.exit_stack.enter_context(
            patch.object(self.kernel_args, "input", patched_input)
        )

        original_output = self.kernel_args.output

        def patched_output(name):
            if name in self.local_buffers:
                return name
            return original_output(name)

        self.exit_stack.enter_context(
            patch.object(self.kernel_args, "output", patched_output)
        )

        # Publish this context through V so codegen can find it.
        self.exit_stack.enter_context(V.set_local_buffer_context(self))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(
        self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None
    ):
        """Register `local_buffer`, optionally replacing the given global buffers."""
        local_name = local_buffer.get_name()
        assert local_name not in self.local_buffers
        self.local_buffers[local_name] = local_buffer
        if not global_buffers:
            return
        for global_buffer in global_buffers:
            global_name = global_buffer.get_name()
            assert (
                global_name not in self.global_buffers
                and global_name not in self.global_to_local
            )
            self.global_buffers[global_name] = global_buffer
            self.global_to_local[global_name] = local_buffer
            # The global buffer is fully replaced by the local one.
            V.graph.removed_buffers.add(global_name)

    def localize_function(
        self,
        fn: Callable[..., Any],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_function,
    ):
        """Wrap `fn` so it runs with a LocalizeBufferHandler ops handler installed."""

        def inner(*args, **kwargs):
            handler = LocalizeBufferHandler(
                V.get_ops_handler(),
                global_to_local=self.global_to_local,
                rewrite_index=rewrite_index,
            )
            with V.set_ops_handler(handler):
                return fn(*args, **kwargs)

        return inner

    def localize_nodes(
        self,
        nodes: List[ir.IRNode],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_nodes,
    ) -> List[ir.IRNode]:
        """
        Given `local_buf` and `global_buf` registered in the current
        `LocalBufferContext` through `add_local_buffer`, localizes the
        `global_buf` to `local_buf` for the given `nodes` and returns a new
        list of IR nodes that work on `local_buf` instead of `global_buf`,
        i.e., all the loads and stores are redirected to `local_buf`. This
        helps the fused loops to work on smaller-sized local buffers for
        better data locality.

        The data access of `local_buf` is assumed to be contiguous with the
        same order as the `global_buf`.
        """
        assert len(nodes) > 0

        def wrap_inner_fn_for_node(node: ir.IRNode):
            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(loops, ir.Loops)
            new_loops = copy.copy(loops)
            if isinstance(node, ir.ComputedBuffer):
                new_node = ir.ComputedBuffer(
                    node.get_name(), node.get_layout(), new_loops
                )
            else:
                new_node = new_loops  # type: ignore[assignment]

            new_loops.inner_fn = self.localize_function(
                new_loops.inner_fn,
                rewrite_index,
            )
            return new_node

        return [wrap_inner_fn_for_node(node) for node in nodes]
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def unify_mask_base_type(
    buffer: IndentedBuffer,
    vars: Tuple[CSEVariable, ...],
    dtype=torch.float,
):
    """
    Cast each CSE variable in `vars` to the mask base type `dtype` and return
    the casted variables (as a lazy generator of new CSE variables).
    """
    return (
        V.kernel.cse.generate(
            buffer,
            f"{V.kernel._get_mask_cast(var, dtype)}",
        )
        for var in vars
    )
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
    """
    Emit C++ for a vectorized random-number generation: an immediately-invoked
    lambda that stores the vector of offsets, runs `rand_function` per lane,
    and loads the results back as a Vectorized/VectorizedN value.
    """
    assert is_integer_dtype(offset.dtype)
    offset_ctype = DTYPE_TO_CPP[offset.dtype]
    result_ctype = DTYPE_TO_CPP[dst_dtype]
    lanes = V.kernel.tiling_factor
    code.writeline("[&]()")
    with code.indent():
        code.writeline(f"{offset_ctype} offset[{lanes}];")
        code.writeline(f"{result_ctype} result[{lanes}];")
        code.writeline(f"{offset}.store(offset);")
        code.writeline(
            f"for( {offset_ctype} offset_idx = 0; offset_idx < {lanes}; offset_idx++ )"
        )
        with code.indent():
            code.writeline(rand_function)
        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
        if num_vectors == 1:
            code.writeline(
                f"return at::vec::Vectorized<{result_ctype}>::loadu(result);"
            )
        else:
            code.writeline(
                f"return at::vec::VectorizedN<{result_ctype}, {num_vectors}>::loadu(result);"
            )
    # Invoke the lambda in place.
    code.writeline("()")
    return code
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def get_gemm_template_output_and_compute_dtype(input_dtype):
    """
    Return the (output_dtype, compute_dtype) pair for the GEMM template:
    int32 accumulation for uint8 inputs, float32 otherwise.
    """
    if input_dtype == torch.uint8:
        return (torch.int32, torch.int32)
    return (torch.float32, torch.float32)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def create_epilogue_with_attr(input_buffer, attr, **kwargs):
    """
    Build a Pointwise epilogue node that applies the post-op `attr` to
    `input_buffer`. Supported attrs: relu, gelu (none/tanh), swish, sigmoid,
    tanh, hardswish, hardsigmoid, leaky_relu, hardtanh, add/sub/mul, bias_add.
    Extra parameters (scalars, other, beta, dtype, algorithm) come via kwargs.

    Raises:
        ValueError: if `attr` is not one of the supported epilogues.
    """
    input_loader = input_buffer.make_loader()
    dtype = input_buffer.get_dtype()
    if attr == "relu":

        def inner_fn(index):
            x = input_loader(index)
            zero = ops.constant(0, dtype)
            return ops.maximum(x, zero)

    elif attr == "gelu":
        assert "algorithm" in kwargs
        if kwargs["algorithm"] == "none":
            # Exact (erf-based) GELU, computed in float.

            def inner_fn(index):
                x = input_loader(index)
                if dtype != torch.float:
                    x = ops.to_dtype(x, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                inv_sqrt2 = ops.constant(0.7071067811865476, torch.float)
                out = x * half * (ops.erf(x * inv_sqrt2) + one)
                if dtype != torch.float:
                    out = ops.to_dtype(out, dtype)
                return out

        else:
            assert kwargs["algorithm"] == "tanh"
            # Tanh-approximated GELU, computed in float.

            def inner_fn(index):
                x = input_loader(index)
                if dtype != torch.float:
                    x = ops.to_dtype(x, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                sqrt_2_over_pi = ops.constant(0.7978845608028654, torch.float)
                coeff = ops.constant(0.044715, torch.float)
                out = (
                    half
                    * x
                    * (one + ops.tanh(sqrt_2_over_pi * (x + coeff * x * x * x)))
                )
                if dtype != torch.float:
                    out = ops.to_dtype(out, dtype)
                return out

    elif attr == "swish":

        def inner_fn(index):
            x = input_loader(index)
            return x * ops.sigmoid(x)

    elif attr == "sigmoid":

        def inner_fn(index):
            return ops.sigmoid(input_loader(index))

    elif attr == "tanh":

        def inner_fn(index):
            return ops.tanh(input_loader(index))

    elif attr == "hardswish" or attr == "hardsigmoid":

        def hardsigmoid_float(x):
            # clamp(x + 3, 0, 6) / 6, all in float
            zero = ops.constant(0, torch.float)
            six = ops.constant(6, torch.float)
            three = ops.constant(3, torch.float)
            one_over_six = ops.constant(0.16666666666666666, torch.float)
            clamped_low = ops.maximum(x + three, zero)
            clamped = ops.minimum(clamped_low, six)
            return clamped * one_over_six

        def inner_fn(index):
            x = input_loader(index)
            if dtype != torch.float:
                x = ops.to_dtype(x, torch.float)
            out = hardsigmoid_float(x)
            if attr == "hardswish":
                out = x * out
            if dtype != torch.float:
                out = ops.to_dtype(out, dtype)
            return out

    elif attr == "leaky_relu":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 1
        negative_slope = kwargs["scalars"][0]

        def inner_fn(index):
            x = input_loader(index)
            if dtype != torch.float:
                x = ops.to_dtype(x, torch.float)
            zero = ops.constant(0, torch.float)
            out = ops.where(
                x > zero, x, x * ops.constant(negative_slope, torch.float)
            )
            if dtype != torch.float:
                out = ops.to_dtype(out, dtype)
            return out

    elif attr == "hardtanh":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 2
        min_value = kwargs["scalars"][0]
        max_value = kwargs["scalars"][1]

        def inner_fn(index):
            x = input_loader(index)
            if dtype != torch.float:
                x = ops.to_dtype(x, torch.float)
            out = ops.minimum(
                ops.maximum(x, ops.constant(min_value, torch.float)),
                ops.constant(max_value, torch.float),
            )
            if dtype != torch.float:
                out = ops.to_dtype(out, dtype)
            return out

    elif attr in ["add", "sub", "mul"]:
        assert "other" in kwargs
        other = kwargs["other"]
        # `other` may have fewer dims than the input; drop the leading
        # input-only dims from the index when loading it (broadcast).
        dims_diff = len(input_buffer.get_size()) - len(other.get_size())
        other_loader = other.make_loader()

        def inner_fn(index):
            op = getattr(ops, attr)
            if dims_diff != 0:
                return op(input_loader(index), other_loader(index[dims_diff:]))
            else:
                return op(input_loader(index), other_loader(index))

    elif attr == "bias_add":
        assert "other" in kwargs
        assert "beta" in kwargs
        assert "dtype" in kwargs
        beta = kwargs["beta"]
        other = kwargs["other"]
        # NOTE: the output dtype is taken from kwargs here, not the input.
        dtype = kwargs["dtype"]
        bias_loader = other.make_loader()

        def inner_fn(index):
            bias = bias_loader(index)
            x = input_loader(index)
            if beta != 1:
                return ops.constant(beta, torch.float) * bias + x
            return bias + x

    else:
        raise ValueError(f"Unsupported epilogue attribute: {attr}")
    return ir.Pointwise(
        device=input_buffer.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input_buffer.get_size(),
    )
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
def _get_loop_body(fn_list):
    """
    Extract the LoopBody objects underlying `fn_list`. Entries may already be
    LoopBody instances, functions wrapped by `localize_function` (carrying an
    `original_fn` attribute), or `functools.partial` objects whose first
    argument holds a `_body`.
    """
    if all(isinstance(fn, LoopBody) for fn in fn_list):
        return fn_list
    if hasattr(fn_list[0], "original_fn"):
        # Local-buffer case: fns were wrapped by localize_function.
        assert all(hasattr(fn, "original_fn") for fn in fn_list)
        assert all(
            isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
        )
        return [fn.original_fn.args[0]._body for fn in fn_list]
    assert all(isinstance(fn, functools.partial) for fn in fn_list)
    assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
    return [fn.args[0]._body for fn in fn_list]
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def _get_dtype_from_loopbodies(loop_bodies):
    """
    Collect the set of dtypes recorded (via OptimizationContext metadata) on
    every `call_method` FX node across the given loop bodies and subblocks.
    """
    dtypes = set()
    for body in loop_bodies:
        all_graphs = [body.root_block.graph]
        all_graphs.extend(sub.graph for sub in body.subblocks.values())
        for graph in all_graphs:
            for node in graph.nodes:
                if node.op == "call_method":
                    dtypes.add(node.meta[OptimizationContext.key].dtype)
    return dtypes
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
from itertools import chain, count
|
| 5 |
+
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
from torch import dtype as torch_dtype
|
| 10 |
+
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
|
| 11 |
+
from torch._inductor.runtime.triton_heuristics import grid as default_grid
|
| 12 |
+
|
| 13 |
+
from .. import config
|
| 14 |
+
from ..codecache import CudaKernelParamCache
|
| 15 |
+
from ..utils import DeferredLineBase
|
| 16 |
+
from ..virtualized import V
|
| 17 |
+
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
| 18 |
+
from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
|
| 19 |
+
from .cpp_utils import cexpr, DTYPE_TO_CPP
|
| 20 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 21 |
+
from .wrapper import SymbolicCallArg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..graph import GraphLowering
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DeferredCudaKernelLine(DeferredLineBase):
    """
    When using cpp wrapper, CUDA kernel load and launch need to wait for
    Triton kernels to be tuned and stored as cubin files, so a deferred line
    backfills that information at materialization time from
    CudaKernelParamCache, via %-substitution of `keys` into `line_template`.
    """

    def __init__(
        self,
        kernel_name: str,
        line_template: str,
        keys: Tuple[str, ...],
    ):
        super().__init__(line_template)
        assert not isinstance(line_template, DeferredLineBase)
        self.kernel_name = kernel_name
        self.line_template = line_template
        self.keys = keys

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        for key in self.keys:
            assert (
                key in params
            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
            if key == get_cpp_wrapper_cubin_path_name():
                # The cubin must already exist on disk by the time we emit.
                assert os.path.exists(params[key]), f"{params[key]} does not exist"

        return self.line_template % tuple(params[key] for key in self.keys)

    def _new_line(self, line):
        return DeferredCudaKernelLine(self.kernel_name, line, self.keys)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DeferredCudaDefaultGrid:
    """
    A container for the default grid, which may be used by DeferredCudaGridLine.
    Evaluation is deferred until the kernel's autotuned block sizes are
    available in CudaKernelParamCache.
    """

    def __init__(
        self,
        kernel_name: str,
        grid,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        self.kernel_name = kernel_name
        self.grid = grid
        self.grid_callable = grid_callable
        self.grid_extra_kwargs = grid_extra_kwargs

    def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]):
        """Recursively replace SymbolicCallArg entries with their inner exprs."""
        if not isinstance(grid, (list, tuple)):
            return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid
        return [self._process_grid(item) for item in grid]

    def __call__(self):
        raw_grid = self.grid
        assert isinstance(raw_grid, (list, tuple)), f"expected {raw_grid=} to be a list"
        processed = self._process_grid(raw_grid)
        grid_callable = self.grid_callable or default_grid
        grid_fn = (
            grid_callable(*processed, **self.grid_extra_kwargs)
            if self.grid_extra_kwargs
            else grid_callable(*processed)
        )

        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        # Feed the autotuned block sizes into the grid function.
        return grid_fn(
            {
                "XBLOCK": params["x_block"],
                "YBLOCK": params["y_block"],
                "ZBLOCK": params["z_block"],
            }
        )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class DeferredCudaGridLine(DeferredLineBase):
    """
    When using cpp wrapper, CUDA kernel load and launch need to wait for
    Triton kernels to be tuned and stored as cubin files, so this deferred
    line emits the `Grid` declaration once the chosen config is known.
    """

    def __init__(
        self,
        kernel_name: str,
        grid_var: str,
        grid,
        autotune_configs,
    ):
        super().__init__("")
        self.kernel_name = kernel_name
        self.grid_var = grid_var
        self.grid = grid
        self.autotune_configs = autotune_configs

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"

        if self.autotune_configs is not None:
            # A user-defined Triton kernel: pick the grid matching the
            # config that won autotuning (recorded in params["meta"]).
            grid = None
            if len(self.grid) == 1:
                grid = self.grid[0]
            else:
                for idx, cfg in enumerate(self.autotune_configs):
                    if all(
                        arg == params["meta"][key] for key, arg in cfg.kwargs.items()
                    ):
                        grid = self.grid[idx]
                        break
            assert grid is not None
        elif isinstance(self.grid, DeferredCudaDefaultGrid):
            grid = self.grid()
        else:
            grid = self.grid

        assert len(grid) != 0, "Grid can't be empty"
        grid_args_str = ", ".join(
            cexpr(V.graph.sizevars.simplify(item)) for item in grid
        )
        return f"    Grid {self.grid_var} = Grid({grid_args_str});"

    def _new_line(self, line):
        return DeferredCudaGridLine(
            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
        )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class CppWrapperCuda(CppWrapperCpu):
|
| 163 |
+
"""
|
| 164 |
+
Generates cpp wrapper for running on GPU and calls CUDA kernels
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
def __init__(self) -> None:
|
| 168 |
+
self.device = "cuda"
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.grid_id = count()
|
| 171 |
+
self.cuda = True
|
| 172 |
+
|
| 173 |
+
def write_header(self):
|
| 174 |
+
if V.graph.is_const_graph:
|
| 175 |
+
# We do not write header for constant graph, it will be written by main module.
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
super().write_header()
|
| 179 |
+
|
| 180 |
+
self.header.splice("#include <filesystem>")
|
| 181 |
+
if config.abi_compatible:
|
| 182 |
+
self.header.splice(
|
| 183 |
+
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
|
| 187 |
+
self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))
|
| 188 |
+
|
| 189 |
+
def write_get_raw_stream(self, index, graph=None):
|
| 190 |
+
name = f"stream{index}"
|
| 191 |
+
self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
|
| 192 |
+
self.writeline(
|
| 193 |
+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
|
| 194 |
+
)
|
| 195 |
+
return name
|
| 196 |
+
|
| 197 |
+
def define_kernel(
|
| 198 |
+
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
|
| 199 |
+
):
|
| 200 |
+
if not cuda:
|
| 201 |
+
return super().define_kernel(name, kernel, metadata, cuda)
|
| 202 |
+
|
| 203 |
+
def generate(self, is_inference):
|
| 204 |
+
self.prefix.writeline("\n")
|
| 205 |
+
if not V.graph.aot_mode:
|
| 206 |
+
for kernel in chain(
|
| 207 |
+
sorted(self.src_to_kernel.values()),
|
| 208 |
+
sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
|
| 209 |
+
):
|
| 210 |
+
self.prefix.writeline(
|
| 211 |
+
maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
|
| 212 |
+
)
|
| 213 |
+
self.prefix.writeline("\n")
|
| 214 |
+
return super().generate(is_inference)
|
| 215 |
+
|
| 216 |
+
def generate_user_defined_triton_kernel(
|
| 217 |
+
self,
|
| 218 |
+
kernel_name: str,
|
| 219 |
+
raw_args: List[Any],
|
| 220 |
+
grid: List[Any],
|
| 221 |
+
configs,
|
| 222 |
+
triton_meta,
|
| 223 |
+
constexprs,
|
| 224 |
+
):
|
| 225 |
+
# in C++ wrapper, we don't pass constexpr args, as they don't
|
| 226 |
+
# get added as parameters to the PTX code compiled from the
|
| 227 |
+
# user-defined Triton kernel (only non-constexpr args do)
|
| 228 |
+
raw_args = [
|
| 229 |
+
raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
|
| 230 |
+
]
|
| 231 |
+
args = [self.val_to_arg_str(v) for v in raw_args]
|
| 232 |
+
arg_types = [
|
| 233 |
+
arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
|
| 234 |
+
for arg in raw_args
|
| 235 |
+
]
|
| 236 |
+
self.generate_kernel_call(
|
| 237 |
+
kernel_name,
|
| 238 |
+
args,
|
| 239 |
+
arg_types=arg_types,
|
| 240 |
+
raw_args=raw_args,
|
| 241 |
+
grid=grid,
|
| 242 |
+
cuda=True,
|
| 243 |
+
triton=True,
|
| 244 |
+
triton_meta=triton_meta,
|
| 245 |
+
autotune_configs=configs,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
@functools.lru_cache(None) # noqa: B019
|
| 249 |
+
def generate_load_kernel_once(
|
| 250 |
+
self,
|
| 251 |
+
kernel_name: str,
|
| 252 |
+
graph: "GraphLowering", # for per-graph caching
|
| 253 |
+
):
|
| 254 |
+
keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem")
|
| 255 |
+
kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
|
| 256 |
+
self.writeline(f"if ({kernel_var_name} == nullptr) {{")
|
| 257 |
+
self.writeline(
|
| 258 |
+
DeferredCudaKernelLine(
|
| 259 |
+
kernel_name,
|
| 260 |
+
""" """
|
| 261 |
+
+ kernel_var_name
|
| 262 |
+
+ """ = loadKernel("%s", "%s", %s, this->cubin_dir_);"""
|
| 263 |
+
if V.graph.aot_mode
|
| 264 |
+
else """ """
|
| 265 |
+
+ kernel_var_name
|
| 266 |
+
+ """ = loadKernel("%s", "%s", %s);""",
|
| 267 |
+
keys,
|
| 268 |
+
)
|
| 269 |
+
)
|
| 270 |
+
self.writeline("}")
|
| 271 |
+
return kernel_var_name
|
| 272 |
+
|
| 273 |
+
def generate_args_decl(self, call_args, arg_types):
|
| 274 |
+
new_args = []
|
| 275 |
+
for arg, arg_type in zip(call_args, arg_types):
|
| 276 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 277 |
+
if isinstance(arg_type, torch_dtype):
|
| 278 |
+
if arg.endswith(".item()"):
|
| 279 |
+
# Need to declare a scalar in this case
|
| 280 |
+
ctype = DTYPE_TO_CPP[arg_type]
|
| 281 |
+
arg = arg[:-7]
|
| 282 |
+
if config.abi_compatible:
|
| 283 |
+
self.codegen_tensor_item(
|
| 284 |
+
arg_type,
|
| 285 |
+
arg,
|
| 286 |
+
var_name,
|
| 287 |
+
)
|
| 288 |
+
else:
|
| 289 |
+
from torch import bfloat16, float16
|
| 290 |
+
|
| 291 |
+
if arg_type in (float16, bfloat16):
|
| 292 |
+
var_name_tmp = f"{var_name}_tmp"
|
| 293 |
+
self.writeline(
|
| 294 |
+
f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();"
|
| 295 |
+
)
|
| 296 |
+
self.writeline(f"float {var_name} = float({var_name_tmp});")
|
| 297 |
+
else:
|
| 298 |
+
self.writeline(
|
| 299 |
+
f"{ctype} {var_name} = {arg}.item<{ctype}>();"
|
| 300 |
+
)
|
| 301 |
+
else:
|
| 302 |
+
if config.abi_compatible:
|
| 303 |
+
self.writeline(
|
| 304 |
+
maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};")
|
| 305 |
+
)
|
| 306 |
+
self.writeline(
|
| 307 |
+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
|
| 308 |
+
)
|
| 309 |
+
else:
|
| 310 |
+
self.writeline(
|
| 311 |
+
maybe_hipify_code_wrapper(
|
| 312 |
+
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
|
| 313 |
+
)
|
| 314 |
+
)
|
| 315 |
+
elif arg_type in (sympy.Integer, int):
|
| 316 |
+
self.writeline(f"int {var_name} = {self.expr_printer(arg)};")
|
| 317 |
+
elif arg_type in (sympy.Float, float):
|
| 318 |
+
self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
|
| 319 |
+
else:
|
| 320 |
+
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
|
| 321 |
+
new_args.append(f"&{var_name}")
|
| 322 |
+
|
| 323 |
+
return ", ".join(new_args)
|
| 324 |
+
|
| 325 |
+
def generate_default_grid(
    self,
    kernel_name: str,
    grid: List[Any],
    cuda: bool = True,
    grid_callable: Optional[Callable[..., Any]] = None,
    **grid_extra_kwargs,
):
    """
    Build the launch-grid description for a kernel.

    For CUDA kernels the concrete grid depends on the kernel config
    chosen by autotuning, which is only known later, so a
    DeferredCudaDefaultGrid placeholder (resolved after autotune, using
    the grid function from triton_heuristics) is returned instead of a
    computed grid. Non-CUDA callers get the grid back unchanged.
    """
    if cuda:
        return DeferredCudaDefaultGrid(
            kernel_name, grid, grid_callable, **grid_extra_kwargs
        )
    return grid
|
| 344 |
+
|
| 345 |
+
def generate_kernel_call(
    self,
    kernel_name: str,
    call_args,
    grid=None,
    device_index=None,
    cuda=True,
    triton=True,
    arg_types=None,
    raw_args=None,
    grid_fn: str = "grid",
    triton_meta=None,
    autotune_configs=None,
    grid_extra_kwargs="",
):
    """
    Emit the C++ wrapper code that launches one kernel: argument
    declarations, the (deferred) grid computation, and the guarded
    launchKernel call. Non-CUDA kernels are delegated to the parent
    implementation.

    Raises:
        AssertionError: if arg_types is missing or misaligned with
            call_args.
    """
    assert arg_types is not None and len(call_args) == len(
        arg_types
    ), "call_args and arg_types do not match"

    if not cuda:
        # Even in CppWrapperCuda, we may see cpp kernels
        return super().generate_kernel_call(
            kernel_name,
            call_args,
            grid,
            device_index,
            cuda,
            triton,
            arg_types,
            raw_args,
            grid_fn,
            triton_meta,
            autotune_configs,
            grid_extra_kwargs,
        )

    device_index, call_args = self.prepare_triton_kernel_call(
        device_index, call_args
    )
    # NOTE(review): this first assignment is overwritten below; the call is
    # kept for its side effect of emitting the one-time kernel-load code.
    kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)

    # args with value 1 are added into equal_to_1 and constants
    # in triton_meta (in the Python codegen) which makes them
    # inlined in the PTX and compiled CUBIN
    if (
        triton_meta is not None
        and "configs" in triton_meta
        and triton_meta["configs"]
    ):
        equal_to_1 = triton_meta["configs"][0].equal_to_1
        call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1]
        arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]

    # Declare C++ locals for each argument and collect their addresses
    # into a void* array for the driver launch.
    call_args_str = self.generate_args_decl(call_args, arg_types)
    kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
    self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
    stream = (
        "stream"
        if V.graph.aot_mode
        else self.write_get_raw_stream(device_index, V.graph)
    )

    # Grid dims are only known after autotune, so emit a deferred line.
    grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
    self.writeline(
        DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
    )

    kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
    # add debug printer code for all triton kernel related calls
    debug_printer_manager = V.graph.wrapper_code.debug_printer
    debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
    with debug_printer_manager:
        # Guard the launch: skip kernels whose grid collapses to zero.
        self.writeline(f"if ({grid_var}.is_non_zero()) {{")
        self.writeline(
            DeferredCudaKernelLine(
                kernel_name,
                # num_warps/shared_mem (the %s slots) are filled in later
                # from the autotuned kernel config.
                r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
                    kernel_var_name,
                    f"{grid_var}.grid_x",
                    f"{grid_var}.grid_y",
                    f"{grid_var}.grid_z",
                    kernel_args_var,
                    stream,
                ),
                ("num_warps", "shared_mem"),
            ),
        )
        self.writeline("}")
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc
ADDED
|
Binary file (7.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc
ADDED
|
Binary file (2.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc
ADDED
|
Binary file (72.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
from typing import cast, Sequence
|
| 4 |
+
|
| 5 |
+
from ...._dynamo.utils import counters
|
| 6 |
+
from ... import config
|
| 7 |
+
from ...codecache import code_hash, get_path
|
| 8 |
+
from ...ir import CUDATemplateBuffer
|
| 9 |
+
from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
|
| 10 |
+
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
|
| 11 |
+
from ...virtualized import V
|
| 12 |
+
from ..common import IndentedBuffer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
log = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CUDACPPScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for CUDA C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.

    It handles fusion decisions and CUDA C++ specific template code generation.
    """

    def __init__(self, scheduler: Scheduler) -> None:
        super().__init__()
        self.scheduler = scheduler

    @classmethod
    def get_backend_features(cls, device):
        # No special backend features advertised for CUDA C++ kernels.
        return {}

    def group_fn(self, sizes):
        # Collapse each size group to a single simplified product expression.
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    @staticmethod
    def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
        """Return True iff node is a SchedulerNode wrapping a CUDATemplateBuffer."""
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CUDATemplateBuffer
        )

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        # Vertical fusion is not supported by this backend.
        return False

    def define_kernel(self, src_code: str, node_schedule) -> str:
        """
        Register src_code with the wrapper and return the kernel's name.

        Deduplicates by source text: identical src_code always maps to the
        same kernel name, and the compile wrapper is only emitted once.
        """
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
            # use the original src_code as the key
            wrapper.src_to_kernel[src_code] = kernel_name
            src_code = src_code.replace("KERNEL_NAME", kernel_name)

            _, _, kernel_path = get_path(code_hash(src_code), "py")

            # Wrap the CUDA source in an async_compile.cuda(...) call so the
            # generated wrapper compiles it to a shared object.
            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline("async_compile.cuda(r'''")
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''', 'so')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cuda_cpp_template(
            template_node
        ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
        template_node = cast(SchedulerNode, template_node)
        _, (numel, rnumel) = template_node.group
        # CUDA C++ templates have no reduction dimension.
        assert rnumel == 1
        ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
        kernel, render = ctb.make_kernel_render(ctb)
        with kernel:
            template_node.mark_run()
            src_code = render()

        with V.set_kernel_handler(kernel):
            node_schedule = [template_node]
            kernel_name = self.define_kernel(src_code, node_schedule)

        # debug printing values of intermediate tensors
        _, call_args, arg_signatures, _ = kernel.args.python_argdefs()
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(
            call_args, kernel_name, arg_signatures, kernel
        )
        with debug_printer_manager:
            kernel.call_kernel(kernel_name, ctb)

        # Propagate buffers the kernel removed and release scheduler buffers.
        V.graph.removed_buffers |= kernel.removed_buffers
        self.scheduler.free_buffers()
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ... import config
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
log = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_cuda_arch() -> Optional[str]:
    """
    Return the target CUDA compute capability as a string (e.g. "80").

    Prefers the explicitly configured config.cuda.arch; otherwise falls
    back to the compute capability of the first visible device. Returns
    None (after logging) if detection fails for any reason.
    """
    try:
        arch = config.cuda.arch
        if arch is not None:
            return str(arch)
        # Get Compute Capability of the first Visible device
        major, minor = torch.cuda.get_device_capability(0)
        return str(major * 10 + minor)
    except Exception as e:
        log.error("Error getting cuda arch: %s", e)
        return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_cuda_version() -> Optional[str]:
    """
    Return the CUDA toolkit version string.

    Prefers the explicitly configured config.cuda.version, falling back
    to torch.version.cuda. Returns None (after logging) on any failure.
    """
    try:
        version = config.cuda.version
        return version if version is not None else torch.version.cuda
    except Exception as e:
        log.error("Error getting cuda version: %s", e)
        return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@functools.lru_cache(None)
def nvcc_exist(nvcc_path: str = "nvcc") -> bool:
    """
    Return True if the given nvcc executable can be located.

    Args:
        nvcc_path: Name or path of the nvcc binary; None is treated as
            "no compiler available".

    The result is cached per path, matching the original behavior of
    probing the environment only once per process.
    """
    if nvcc_path is None:
        return False
    import shutil

    # shutil.which is portable (including Windows) and avoids spawning a
    # subprocess, unlike shelling out to the Unix-only `which` command.
    return shutil.which(nvcc_path) is not None
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
| 4 |
+
|
| 5 |
+
from ...autotune_process import CUDABenchmarkRequest
|
| 6 |
+
from ...ir import (
|
| 7 |
+
Buffer,
|
| 8 |
+
ChoiceCaller,
|
| 9 |
+
CUDATemplateBuffer,
|
| 10 |
+
IRNode,
|
| 11 |
+
Layout,
|
| 12 |
+
PrimitiveInfoType,
|
| 13 |
+
TensorBox,
|
| 14 |
+
)
|
| 15 |
+
from ...utils import sympy_product
|
| 16 |
+
from ...virtualized import V
|
| 17 |
+
from ..common import IndentedBuffer, Kernel, OpOverrides
|
| 18 |
+
from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
|
| 23 |
+
|
| 24 |
+
log = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
cexpr = CppPrinter().doprint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _normalize_idx(index: int, total_length: int) -> int:
|
| 30 |
+
return index if index >= 0 else index + total_length
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CUDAKernel(Kernel):
    """
    Baseclass for CUDA / Cutlass based Kernels
    """

    # Op-level overrides shared by all CUDA kernels; the ignore silences the
    # type mismatch with the base class attribute declaration.
    overrides = OpOverrides  # type: ignore[assignment]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CUDATemplateKernel(CUDAKernel):
    """
    Template kernels defined by CUDA / Cutlass in C++.
    """

    # Trailing parameters appended to every generated kernel signature.
    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"

    def __init__(self, kernel_name) -> None:
        """
        Initializes a new instance of the CUDATemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
        """
        super().__init__()
        self.kernel_name = kernel_name
        # Mapping from arg name to IRNode.
        self.named_nodes: Dict[str, IRNode] = {}

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.
        """
        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.

        Returns an empty string when the node is absent or has no
        registered argument name; otherwise a C++ block raising
        std::runtime_error for a null pointer with nonzero size.
        """

        if node is None:
            return ""

        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
              if (!{name_str}) {{
                int64_t {name_str}_size = {size_str};
                if ({name_str}_size > 0) {{
                  throw std::runtime_error("input {name_str} is null but size is not 0!");
                }}
              }}
            }}
            """
        )
        return res.getvalue()

    def def_kernel(
        self,
        inputs: List[IRNode],
        outputs: List[IRNode],
        names_str: str = "",
        input_reorder: Optional[List[int]] = None,
    ) -> str:
        """
        Hook called from template code to generate function definition and
        needed args.

        Args:
            inputs: List of input IRNodes
            outputs: List of output IRNodes
            names_str: Comma separated list of input + output argument names.
            input_reorder: The actual order of input nodes.
                e.g. The template might have input argument defined as [X, W, Bias],
                and the actual input passed into this template could be [Bias, X, W].
                In this case, the `input_reorder` would be [2, 0, 1].

        Raises:
            RuntimeError: if names_str does not name every input and output.
        """

        names = [x.strip() for x in names_str.strip().split(",")]
        if len(inputs) + len(outputs) != len(names):
            raise RuntimeError(
                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
            )

        if input_reorder is not None:
            assert len(inputs) == len(input_reorder)
        else:
            input_reorder = list(range(len(inputs)))

        # Register inputs (in their actual order) and outputs under their
        # template-facing names.
        for idx in input_reorder:
            name = names[idx]
            node = inputs[idx]
            if node is not None:
                self.named_nodes[name] = node
                self.args.input_buffers[node.get_name()] = name

        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
            if node is not None:
                self.named_nodes[name] = node
                self.args.output_buffers[node.get_name()] = name

        arg_defs, *_ = self.args.cpp_argdefs()
        return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"

    def call_kernel(
        self,
        name: str,
        node: "CUDATemplateBuffer",  # type: ignore[name-defined]
    ) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.
        used from within torch._inductor.wrapper.WrapperCodeGen

        name: Name of kernel function.
        node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
        as well as all required inputs and outputs.
        """
        wrapper = V.graph.wrapper_code
        _, call_args, _, arg_types = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
            else:
                call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"

        # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
        # workspace_size should have already been retrieved prior to this call.
        call_args.append("None")

        if node.get_workspace_size() > 0:
            wrapper.generate_workspace_allocation(
                node.get_workspace_size(), V.graph.scheduler.current_device, False
            )
            call_args.append("c_void_p(workspace.data_ptr())")
        else:
            call_args.append("None")

        wrapper.generate_kernel_call(
            name,
            call_args,
            cuda=True,
            triton=False,
            arg_types=arg_types,
        )
        if node.get_workspace_size() > 0:
            # Free the temporary workspace right after the launch is emitted.
            wrapper.writeline(wrapper.make_free_by_names(["workspace"]))

    def dtype(self, node: IRNode) -> Optional[str]:
        """
        Generates code which represents dtype of a given node.
        """

        if node is None:
            return "void"
        return DTYPE_TO_CPP.get(node.get_layout().dtype)

    def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
        # Helper method, called into from CUTLASSGemmTemplate
        if node is None:
            return default_dtype
        # Imported locally to avoid a module-level import cycle.
        from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate

        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

    def max_valid_index(self, node: IRNode, default=-1):
        # Helper method, called into from CUTLASSGemmTemplate
        # Largest flat offset addressable in the node's layout:
        # sum over dims of (size - 1) * stride.
        if node is None:
            return default
        max_valid_offset = 0
        for i in range(len(node.get_size())):
            max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
        return max_valid_offset

    def offset(self, node: IRNode) -> str:
        """
        Generates code which represents offset of a given node.
        """

        if node is None:
            return "0"
        return str(node.get_layout().offset)

    def ptr(self, node: IRNode) -> str:
        """
        Generates code which represents pointer of a given node.
        """

        if node is None:
            return "nullptr"
        arg_name = self.arg_name(node)
        if arg_name is None:
            return "nullptr"
        offset = self.offset(node)
        return arg_name if offset == "0" else f"{arg_name} + {offset}"

    def size(
        self,
        node: IRNode,
        start_index: int,
        end_index: Optional[int] = None,
        default_value: int = 0,
    ) -> str:
        """
        Hook called from template code to get the size of an arg.
        Generates code which represents size of a given node in [start_index, end_index).
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        # Negative indices count from the end, like Python slicing.
        start_index = _normalize_idx(start_index, len(node.get_size()))
        if end_index is None:
            end_index = start_index
        end_index = _normalize_idx(end_index, len(node.get_size()))

        sizes = node.get_size()[start_index : end_index + 1]
        if len(sizes) == 0:
            return str(default_value)

        val = sympy_product(sizes)
        return cexpr(self.rename_indexing(val))

    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
        """
        Hook called from template code to get the stride of an arg.
        Generates code which represents stride of a given node at index.
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        index = _normalize_idx(index, len(node.get_size()))
        if index < 0:
            return str(default_value)

        stride = node.get_stride()[index]
        return cexpr(self.rename_indexing(stride))

    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the row or column stride of an arg.
        This is required by some CUTLASS 2.X APIs.
        If the node is in row_major, it returns stride[-2].
        If the node is in column_major, it returns stride[-1].

        TODO: Will add needed args to pass it in if it is dynamic.

        Raises:
            RuntimeError: if neither of the last two strides is 1 (i.e. the
                node is neither row- nor column-major in its last two dims).
        """

        if node is None or len(node.get_stride()) < 2:
            return str(default_value)

        stride0 = node.get_stride()[-1]
        stride1 = node.get_stride()[-2]
        if stride0 == 1:
            return cexpr(self.rename_indexing(stride1))
        elif stride1 == 1:
            return cexpr(self.rename_indexing(stride0))
        else:
            raise RuntimeError(
                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
            )
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class CUDATemplateCaller(ChoiceCaller):
    """
    CUDATemplateCaller

    This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
        template_buffer (CUDATemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: List[Buffer],
        layout: Layout,
        make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str],
        bmreq: CUDABenchmarkRequest,
        template: "CUDATemplate",  # type: ignore[name-defined]
        info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]],  # type: ignore[type-arg]
    ) -> None:
        super().__init__(name, input_nodes, layout)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        # Compilation is delegated to the benchmark request.
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        """Benchmark this choice; returns the measured time as a float."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(
            *args, output_tensor=out
        )  # @TODO: Hack for ensuring that Cutlass Kernel is preferred

    def __str__(self) -> str:
        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        # Kernels are looked up on the generated `cuda_template_kernels`
        # namespace in the output code.
        return f"cuda_template_kernels.{self.name}"

    def hash_key(self) -> str:
        # Combine the caller category with the benchmark request's hash so
        # equal sources in different categories do not collide.
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        if self.info_kwargs is not None and "op" in self.info_kwargs:
            op: Any = self.info_kwargs["op"]
            return {
                "backend": "CUDA",
                "op_type": type(op).__name__,
                "op_conf_name": str(op.configuration_name()),
                "op_arch": str(op.arch),
                "tile_shape": str(op.tile_description.tile_shape),
                "epilogue_schedule": str(op.epilogue_schedule),
                "kernel_schedule": str(op.kernel_schedule),
                "element_accumulator": str(op.accumulator_type()),
                "op_name": str(op.procedural_name()),
                "instruction_shape": str(
                    op.tile_description.math_instruction.instruction_shape
                ),
            }
        else:
            return {"backend": "CUDA", "op_type": "unknown"}

    def output_node(self) -> TensorBox:
        """Materialize this choice as a CUDATemplateBuffer wrapped in a TensorBox."""
        # Workspace size may only be known after compilation; refresh it
        # before constructing the output buffer.
        self.bmreq.update_workspace_size()
        return TensorBox.create(
            CUDATemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                workspace_size=self.bmreq.workspace_size,
                template=self.template,
            )
        )
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
from unittest.mock import patch
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from ...autotune_process import CUDABenchmarkRequest, TensorMeta
|
| 13 |
+
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
|
| 14 |
+
from ...utils import IndentedBuffer, unique
|
| 15 |
+
from ...virtualized import V
|
| 16 |
+
from ..common import KernelTemplate
|
| 17 |
+
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CUDATemplate(KernelTemplate):
    # Process-wide monotonic counter shared by all CUDATemplate instances;
    # consumed in generate() to give every benchmark candidate a unique hash name.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: List[Buffer],
        layout: Layout,
        input_reorder: Optional[List[int]] = None,
    ) -> None:
        """
        Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.

        Args:
            name (str): The name of the CUDATemplate object.
            input_nodes (List[IRNode]): A list of input IRNodes.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.

        """
        super().__init__(name)
        self.input_nodes = input_nodes
        # Placeholder output buffer ("buf_out") carrying the requested layout;
        # the real output buffer is substituted later by the scheduler.
        self.output_node: Buffer = Buffer("buf_out", layout)
        self.input_reorder = input_reorder
        self.layout = layout

    def generate(  # type: ignore[override]
        self,
        **kwargs,
    ) -> CUDATemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
        may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.

        Args:
            kwargs: Additional keyword arguments.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.
        """
        kernel_name = f"cuda_{self.name}"
        # Render the kernel source while V.graph.get_dtype is patched to also
        # resolve the placeholder output buffer. _fake_get_dtype is inherited
        # from KernelTemplate (not visible here) — presumably it maps
        # "buf_out" to its dtype; TODO confirm against KernelTemplate.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CUDATemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        # Default reorder is the identity permutation over the inputs.
        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        # The kernel's call args must begin with the (reordered, de-duplicated)
        # input buffer names followed by the output buffer name.
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.extend([self.output_node.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # Any remaining call args are symbolic size expressions; resolve them
        # to concrete integer hints for the standalone benchmark run.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        # Deferred renderer invoked by the scheduler once a real
        # CUDATemplateBuffer (and optional fused epilogue nodes) exist.
        # Closes over this call's kwargs so template-specific arguments
        # (e.g. the CUTLASS "op") are re-applied at final render time.
        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            epilogue_nodes: Optional[List[IRNode]] = None,
        ):
            kernel = CUDATemplateKernel(
                kernel_name="KERNEL_NAME",
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
            kwargs,
        )

    def header(self) -> IndentedBuffer:
        """Return the common C++ header includes emitted for every CUDA template."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Return global C++ declarations (export macro, bfloat16 alias) shared by all CUDA templates."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
                using bfloat16 = nv_bfloat16;
            """
        )
        return res

    def render(self, **kwargs) -> str:
        """Render the kernel source code. Must be implemented by subclasses."""
        raise NotImplementedError
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        """Extend the base header with the CUTLASS / CuTe includes required by generated kernels."""
        res = super().header()
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Extend the base globals with CUTLASS helpers: the status-check macro and a pass-through EVT functor."""
        res = super().globals()
        res.splice(
            """
                using namespace cute;
                #define CUTLASS_CHECK(status)                                                      \\
                {                                                                                  \\
                  cutlass::Status error = status;                                                  \\
                  if (error != cutlass::Status::kSuccess) {                                        \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
                    throw std::runtime_error(msg);                                                 \\
                  }                                                                                \\
                }

                // Used as pass-through functor in EVT just for type casting / rounding
                template <typename T>
                struct identity_op {
                  CUTLASS_HOST_DEVICE
                  T operator()(T val) const { return val; }
                };

            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """Render an integer stride/extent for CuTe: "1" becomes a compile-time Int<1>, anything else stays dynamic."""
        res = ""
        if int_str in {"1", "1L"}:
            res = "cute::Int<1>{}"
        else:
            res = int_str

        return f"{res} /* {var_name} */"

    # Mapping from torch dtypes to the CUTLASS element type names used in
    # generated C++ source.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int32_t",
        torch.int16: "int16_t",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
    }

    # Metadata element types for CUTLASS sparse (2:4) kernels; keyed by the
    # torch dtype of the sparse metadata tensor.
    _DTYPE_TO_CUTLASS_SPARSE_META = {
        torch.int32: "uint32_t",
        torch.int16: "uint16_t",
    }

    def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
        """Return a C-style cast of ``ptr`` to the CUTLASS element type matching ``node``'s dtype (no-op if node is None)."""
        if node is None:
            return ptr
        else:
            return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"

    def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str:
        """Return a C-style cast of ``ptr`` to the CUTLASS sparse-metadata type matching ``node``'s dtype (no-op if node is None)."""
        if node is None:
            return ptr
        else:
            return (
                f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
            )
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
from unittest.mock import patch
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch._inductor.virtualized as virtualized
|
| 8 |
+
from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise
|
| 9 |
+
from torch._inductor.utils import IndentedBuffer, sympy_str
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Used as a magic string to indicate an unsupported sympy expression
|
| 13 |
+
# became part of generated C++ code.
|
| 14 |
+
_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _arg_str(a):
|
| 18 |
+
if isinstance(a, sympy.Expr):
|
| 19 |
+
# If this return value containing the _MAGIC_SYMPY_ERROR_STRING
|
| 20 |
+
# is used as part of the final generated C++ code,
|
| 21 |
+
# a CUTLASSEVTOpNotImplementedError is raised to indicate that
|
| 22 |
+
# the op could not be converted to a valid EVT expression.
|
| 23 |
+
return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')"
|
| 24 |
+
return str(a)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CUTLASSEVTOpNotImplementedError(NotImplementedError):
    """Raised when an IR op has no CUTLASS Epilogue-Visitor-Tree translation."""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CutlassEVTEpilogueTypeFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) functor declarations.

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.


    """

    def __init__(self, accumulator_node_name, evt_type_name):
        """

        Initialize an instance of CutlassEVTEpilogueTypeFormatter.

        Parameters:
        - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused)
                                       IR graph.
        - evt_type_name (str):      The output name of the EVT type we are generating.

        """
        self.accumulator_node_name = accumulator_node_name
        # Accumulates the generated "using ..." C++ type declarations.
        self.output = IndentedBuffer(0)
        # Counter for generated EVT_expr_<n> type alias names.
        self.var_counter = 0
        self.evt_type_name = evt_type_name
        # Maps epilogue buffer names to the EVT subexpression that computes them.
        self.aliases = {}

    @staticmethod
    def ir_to_evt_string(
        template_output_node_name: str,
        evt_type_name: str,
        epilogue_nodes: List[IRNode],
    ):
        """
        Formats IR nodes into a string representation compatible with Cutlass EVT format.

        Args:
            template_output_node_name (str): The name of the template output node.
            evt_type_name (str): The name of the EVT type.
            epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be
                ComputedBuffer nodes wrapping Pointwise nodes.

        Returns:
            A string representation of the IR nodes formatted according to the Cutlass EVT format.

        Raises:
            RuntimeError: if an epilogue node is not a ComputedBuffer.
            CUTLASSEVTOpNotImplementedError: if an op has no EVT translation or
                a sympy/indexing expression leaked into the generated code.
        """
        formatter = CutlassEVTEpilogueTypeFormatter(
            template_output_node_name, evt_type_name
        )

        # Install this formatter as the V.ops handler so that tracing each
        # Pointwise inner_fn dispatches into the _op_* methods below.
        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                if isinstance(node, ComputedBuffer):
                    pnode = node.data
                else:
                    raise RuntimeError(
                        "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
                    )
                assert isinstance(pnode, Pointwise)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single "using" statement and may refer to the previous steps by name
                formatter.aliases[node.name] = result
            # NOTE(review): if epilogue_nodes is empty, `result` is unbound
            # here and this raises NameError (hence the ignore) — callers
            # apparently always pass at least one node; confirm upstream.
            res = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            else:
                return res

    def __getattr__(self, name):
        """
        Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.
        """

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            fn = getattr(self, f"_op_{name}")
            line = fn(*fargs, **fkwargs)
            self.var_counter += 1
            varname = f"EVT_expr_{self.var_counter}"
            # replace line with a new variable name
            self.output.writeline(f"using {varname} = {line};")
            return varname

        # Refuse private/dunder lookups outright (also avoids recursion on
        # attribute access during normal Python object protocol).
        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        if hasattr(self, f"_op_{name}"):
            return inner
        else:
            raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # Load an input to an operation. Might be the output of the matmul, the result
        # of a previous epilogue node, a constant or (TODO) an auxiliary input.
        if name == self.accumulator_node_name:
            return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Load a constant; only float16/float32 scalars are supported.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Perform a named operation on two inputs
        # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>"  # noqa: B950

    def _convert_to_output_dtype(self, a):
        # Convert the final output to the dtype of the output buffer
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>"  # noqa: B950

    def _op_to_dtype(self, a, *args, **kwargs):
        # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator
        # dtype.
        # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible
        # throughout the fusion chain.
        return a  # noqa: B950

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0)
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions are not representable in an EVT epilogue.
        raise CUTLASSEVTOpNotImplementedError

    # Add more ops here...
    def getvalue(self, result) -> str:
        # Return final result: cast the last emitted EVT expression to the
        # output dtype and bind it to the requested EVT type name.
        dtype_converted_expr = self._convert_to_output_dtype(
            f"EVT_expr_{self.var_counter}"
        )
        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
        return self.output.getvalue()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class CutlassEVTEpilogueArgumentFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.


    """

    def __init__(self, accumulator_node_name: str):
        """

        Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly.
        Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method.

        Args:
            accumulator_node_name (str): The name of the accumulator node which should contain
                                          the Matmul result before fusion according to the IR graph.
        """
        self.accumulator_node_name: str = accumulator_node_name  # name of the unfused matmul output buffer
        self.output: IndentedBuffer = IndentedBuffer(0)  # The output buffer for codegen
        self.var_counter: int = (
            0  # used to generate variable names, incremented for each new variable
        )
        self.aliases: Dict[str, str] = {}  # Aliases for subexpression functors

    @staticmethod
    def ir_to_evt_argument_string(
        template_output_node_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """Trace the epilogue nodes and return the C++ brace-initializer string
        that supplies runtime arguments for the matching EVT type declaration
        (see CutlassEVTEpilogueTypeFormatter).

        Raises CUTLASSEVTOpNotImplementedError if an op has no translation or
        a sympy/indexing expression leaked into the generated code.
        """
        formatter = CutlassEVTEpilogueArgumentFormatter(
            template_output_node_name,
        )

        # Install this formatter as the V.ops handler so that tracing each
        # Pointwise inner_fn dispatches into the _op_* methods below.
        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                assert isinstance(node, ComputedBuffer)
                pnode = node.data
                assert isinstance(pnode, Pointwise)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single "using" statement and may refer to the previous steps by name
                if node.name is not None:
                    formatter.aliases[node.name] = result

            # NOTE(review): with empty epilogue_nodes, `result` is unbound
            # here (hence the ignore) — callers appear to guarantee at least
            # one node; confirm upstream.
            res: str = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            else:
                return res

    def __getattr__(self, name):
        # Resolve V.ops.<name> calls to the matching _op_<name> method,
        # stringifying all arguments first.
        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            fn = getattr(self, f"_op_{name}")
            line = fn(*fargs, **fkwargs)
            return line

        # Refuse private/dunder lookups outright.
        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)

        if hasattr(self, f"_op_{name}"):
            return inner
        else:
            raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # The accumulator needs no argument ({}); previously-computed
        # epilogue subexpressions are substituted by alias.
        if name == self.accumulator_node_name:
            return "{}"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Scalar broadcast argument; only float16/float32 are supported.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return "{ static_cast<ElementAcc>(" + str(value) + ") }"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Nested brace initializer matching the Sm90EVT<Sm90Compute<op>,a,b>
        # type generated by CutlassEVTEpilogueTypeFormatter.
        return f"{{ /*{op}: */ {a}, {b} }}"

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0): argument pair for the maximum functor.
        const_zero = self._op_constant(0.0, "torch.float32")
        return "{" + str(a) + ", " + const_zero + "}"

    def _op_to_dtype(self, a, dtype, src_dtype=None):
        # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible
        # throughout the fusion chain.
        assert dtype in (
            "torch.float32",
            "torch.float16",
        ), f"Unsupported dtype: {dtype}"
        assert src_dtype in (
            None,
            "torch.float32",
            "torch.float16",
        ), f"Unsupported source dtype: {src_dtype}"
        return a

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions are not representable in an EVT epilogue.
        raise CUTLASSEVTOpNotImplementedError

    def getvalue(self, result) -> str:
        # Wrap the final argument expression in the outermost brace initializer.
        return "{" + str(result) + "}"
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|