koichi12 commited on
Commit
a2ec7d8
·
verified ·
1 Parent(s): b2f8f15

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. .venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 +3 -0
  3. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc +3 -0
  4. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc +3 -0
  5. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc +3 -0
  6. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc +3 -0
  7. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc +3 -0
  8. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py +0 -0
  9. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp +87 -0
  25. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py +2167 -0
  26. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py +0 -0
  27. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1043 -0
  28. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py +850 -0
  29. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py +128 -0
  30. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py +384 -0
  31. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py +916 -0
  32. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +0 -0
  33. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py +432 -0
  34. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
  35. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +114 -0
  45. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +46 -0
  46. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +397 -0
  47. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +258 -0
  48. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +361 -0
  49. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
  50. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -129,3 +129,11 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
129
  .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
130
  .venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
131
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
129
  .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
130
  .venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
131
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
132
+ .venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text
133
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
134
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
135
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
136
+ .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
137
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
138
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
139
+ .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78df2f31f6db8142ec546a1e5a31cb066f7892d12d2f665b448f8069a08ef807
3
+ size 251616632
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1ba20726a513f57e01fc1fbf9c3744defdeda5d64e6e3a00d7d3911f4f598d2
3
+ size 164293
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:495965e46b513011b3387880294e810069bb3299277002dd35d6e15e1a3d6508
3
+ size 118734
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c863587cdf0f8eef657d2fa0f0ebf9ddddc19a24d5670869719203bc7d877e48
3
+ size 337621
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd3ce0e8ac0de613f90615aaf063bff822e142ca75c5993718647f82d9d0add5
3
+ size 109858
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a247c49f4e32c680bff1ed7aa611b6eae0c91d995c632fbc3fa35605649b638b
3
+ size 109445
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc ADDED
Binary file (1.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc ADDED
Binary file (3.87 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc ADDED
Binary file (49.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc ADDED
Binary file (32.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc ADDED
Binary file (7.74 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc ADDED
Binary file (57.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc ADDED
Binary file (6.63 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc ADDED
Binary file (7.01 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc ADDED
Binary file (44.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc ADDED
Binary file (65.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc ADDED
Binary file (9.27 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc ADDED
Binary file (8.65 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // NOTE: Like interface.cpp, this file will be copied into AOTInductor
2
+ // generated output. This file is intended to keep implementation
3
+ // details separate from the implementation of the AOTI public
4
+ // interface. Note also that #includes should go into interface.cpp
5
+ // for simplicity of maintenance.
6
+
7
+ namespace torch {
8
+ namespace aot_inductor {
9
+ template <typename T>
10
+ void convert_output_to_handle(
11
+ const ArrayRefTensor<T>& output,
12
+ AtenTensorHandle& handle) {
13
+ handle = output.expensiveCopyToTensor();
14
+ }
15
+
16
+ template <typename... Ts, std::size_t... Is>
17
+ void convert_outputs_to_handles_helper(
18
+ const std::tuple<ArrayRefTensor<Ts>...>& outputs,
19
+ AtenTensorHandle* output_handles,
20
+ std::index_sequence<Is...>) {
21
+ (convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
22
+ }
23
+ template <typename... Ts>
24
+ void convert_outputs_to_handles(
25
+ const std::tuple<ArrayRefTensor<Ts>...>& outputs,
26
+ AtenTensorHandle* output_handles) {
27
+ convert_outputs_to_handles_helper(
28
+ outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
29
+ }
30
+
31
+ template <typename T>
32
+ void convert_handle_to_arrayref_tensor(
33
+ AtenTensorHandle handle,
34
+ ArrayRefTensor<T>& input) {
35
+ void* data_ptr;
36
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
37
+ int64_t dim;
38
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
39
+ int64_t numel;
40
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
41
+ int64_t* sizes;
42
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
43
+ int64_t* strides;
44
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
45
+ int32_t dtype;
46
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
47
+ int32_t device_type;
48
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
49
+ int32_t device_index;
50
+ AOTI_TORCH_ERROR_CODE_CHECK(
51
+ aoti_torch_get_device_index(handle, &device_index));
52
+
53
+ input = ArrayRefTensor<T>(
54
+ MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
55
+ MiniArrayRef<const int64_t>(sizes, dim),
56
+ MiniArrayRef<const int64_t>(strides, dim),
57
+ device_type,
58
+ device_index);
59
+ }
60
+
61
+ template <typename... Ts, std::size_t... Is>
62
+ void convert_handles_to_inputs_helper(
63
+ AtenTensorHandle* input_handles,
64
+ std::tuple<ArrayRefTensor<Ts>...>& inputs,
65
+ std::index_sequence<Is...>) {
66
+ (convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
67
+ ...);
68
+ }
69
+
70
+ template <typename... Ts>
71
+ void convert_handles_to_inputs(
72
+ AtenTensorHandle* input_handles,
73
+ std::tuple<ArrayRefTensor<Ts>...>& inputs) {
74
+ convert_handles_to_inputs_helper(
75
+ input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
76
+ }
77
+
78
+ template <typename T>
79
+ void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
80
+ if (tensor.numel() != numel) {
81
+ std::stringstream err;
82
+ err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
83
+ throw std::runtime_error(err.str());
84
+ }
85
+ }
86
+ } // namespace aot_inductor
87
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py ADDED
@@ -0,0 +1,2167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import dataclasses
4
+ import functools
5
+ import itertools
6
+ import logging
7
+ import math
8
+ import operator
9
+ import re
10
+ from enum import auto, Enum
11
+ from itertools import chain
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ ClassVar,
16
+ Dict,
17
+ List,
18
+ NamedTuple,
19
+ Optional,
20
+ Tuple,
21
+ Union,
22
+ )
23
+
24
+ import sympy
25
+ from sympy.printing.printer import Printer
26
+
27
+ import torch
28
+ import torch.fx
29
+ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
30
+ from torch.utils import _pytree as pytree
31
+ from torch.utils._ordered_set import OrderedSet
32
+ from torch.utils._sympy.numbers import int_oo
33
+ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
34
+ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
35
+
36
+ from .. import config, metrics
37
+ from ..utils import (
38
+ DeferredLineBase,
39
+ generate_assert,
40
+ IndentedBuffer,
41
+ sympy_dot,
42
+ sympy_subs,
43
+ unique,
44
+ )
45
+ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
46
+
47
+
48
+ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
49
+
50
+
51
+ def data_type_logger(msg):
52
+ if schedule_log.isEnabledFor(logging.DEBUG):
53
+ schedule_log.debug("Data type propagation: %s", msg)
54
+
55
+
56
+ @dataclasses.dataclass
57
+ class WorkspaceArg:
58
+ """A temporary buffer used for a single kernel, then discarded.
59
+
60
+ Not registered as a traditional buffer since there are no users,
61
+ so it would be dead code eliminated.
62
+ """
63
+
64
+ nbytes: sympy.Expr
65
+ zero_fill: bool
66
+
67
+
68
+ @dataclasses.dataclass
69
+ class TensorArg:
70
+ name: str
71
+ buffer: str
72
+ dtype: torch.dtype
73
+ offset: sympy.Expr = sympy.Integer(0) # c++ only
74
+ alias_of: Optional[str] = None # halide only
75
+
76
+
77
+ @dataclasses.dataclass
78
+ class SizeArg:
79
+ name: str
80
+ expr: sympy.Expr
81
+
82
+ @property
83
+ def alias_of(self):
84
+ return None
85
+
86
+
87
+ @dataclasses.dataclass
88
+ class DeviceCodegen:
89
+ scheduling: Any
90
+ wrapper_codegen: type
91
+ cpp_wrapper_codegen: type = type(None)
92
+
93
+
94
+ KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
95
+
96
+ device_codegens: Dict[str, DeviceCodegen] = {}
97
+
98
+
99
+ class DeviceOpOverrides:
100
+ def import_get_raw_stream_as(self, name):
101
+ raise NotImplementedError
102
+
103
+ def set_device(self, device_idx):
104
+ raise NotImplementedError
105
+
106
+ def synchronize(self):
107
+ raise NotImplementedError
108
+
109
+ def device_guard(self, device_idx):
110
+ raise NotImplementedError
111
+
112
+
113
+ device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
114
+
115
+
116
+ # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
117
+ # For any new backend looking to integrate with Inductor, customization of these two main
118
+ # parts are necessary to generate its specific code.
119
+ #
120
+ # Kernel code generation is determined by different Scheduling. Consequently, a new
121
+ # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
122
+ # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
123
+ #
124
+ # For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
125
+ # that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
126
+ # and override specific member functions to create backend-specific Python wrapper code.
127
+ #
128
+ # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
129
+ # of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
130
+ # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
131
+ # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
132
+ # register_backend_for_device, to equip a new backend at runtime.
133
+ #
134
+ # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
135
+ # This backend can be used as a reference:
136
+ # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
137
+ def register_backend_for_device(
138
+ device: str,
139
+ device_scheduling: Any,
140
+ device_wrapper_codegen: type,
141
+ device_cpp_wrapper_codegen: type = type(None),
142
+ ):
143
+ device_codegens[device] = DeviceCodegen(
144
+ device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
145
+ )
146
+
147
+
148
+ class BackendFeature(Enum):
149
+ FOREACH = auto()
150
+ BUCKETIZE = auto()
151
+ INPLACE_BUFFERS = auto()
152
+ MASKED_SCATTER_WITH_INDEX = auto()
153
+ SCAN = auto()
154
+ SORT = auto()
155
+ TUPLE_REDUCTION = auto()
156
+ PREFER_STORE_LOOP_ORDER = auto()
157
+ TRITON_TEMPLATES = auto()
158
+ REDUCE_TO_SINGLE_ELEMENT = auto()
159
+
160
+
161
+ def get_backend_features(device: Union[torch.device, str]):
162
+ init_backend_registration()
163
+ if isinstance(device, torch.device):
164
+ device_type = device.type
165
+ else:
166
+ assert isinstance(device, str)
167
+ device_type = device
168
+ device = torch.device(device_type)
169
+ scheduling = get_scheduling_for_device(device_type)
170
+ return scheduling(None).get_backend_features(device)
171
+
172
+
173
+ def has_backend_feature(device, feature):
174
+ """See also V.graph.has_feature"""
175
+ assert isinstance(feature, BackendFeature)
176
+ return feature in get_backend_features(device)
177
+
178
+
179
+ def get_scheduling_for_device(device: str):
180
+ return device_codegens[device].scheduling if device in device_codegens else None
181
+
182
+
183
+ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
184
+ if device in device_codegens:
185
+ wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
186
+ return (
187
+ wrapper_codegen_obj.cpp_wrapper_codegen
188
+ if cpp_wrapper
189
+ else wrapper_codegen_obj.wrapper_codegen
190
+ )
191
+ else:
192
+ return None
193
+
194
+
195
+ @functools.lru_cache(None)
196
+ def init_backend_registration():
197
+ from .cpp import CppScheduling
198
+ from .cpp_wrapper_cpu import CppWrapperCpu
199
+ from .cpp_wrapper_cuda import CppWrapperCuda
200
+ from .cuda_combined_scheduling import CUDACombinedScheduling
201
+ from .halide import HalideScheduling
202
+ from .triton import TritonScheduling
203
+ from .wrapper import WrapperCodeGen
204
+
205
+ if get_scheduling_for_device("cpu") is None:
206
+ cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
207
+ register_backend_for_device(
208
+ "cpu",
209
+ lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
210
+ WrapperCodeGen,
211
+ CppWrapperCpu,
212
+ )
213
+
214
+ if get_scheduling_for_device("cuda") is None:
215
+ # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
216
+ cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
217
+ register_backend_for_device(
218
+ "cuda",
219
+ lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
220
+ WrapperCodeGen,
221
+ CppWrapperCuda,
222
+ )
223
+
224
+ if get_scheduling_for_device("xpu") is None:
225
+ register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)
226
+
227
+ private_backend = torch._C._get_privateuse1_backend_name()
228
+ if (
229
+ private_backend != "privateuseone"
230
+ and get_scheduling_for_device(private_backend) is None
231
+ ):
232
+ from torch.utils.backend_registration import _get_custom_mod_func
233
+
234
+ try:
235
+ device_scheduling = _get_custom_mod_func("Scheduling")
236
+ wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
237
+ cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
238
+ if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
239
+ register_backend_for_device(
240
+ private_backend,
241
+ device_scheduling,
242
+ wrapper_codegen,
243
+ cpp_wrapper_codegen,
244
+ )
245
+ except RuntimeError:
246
+ pass
247
+
248
+
249
+ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
250
+ from ..ir import FlexibleLayout
251
+
252
+ # added contiguous index prevents reordering
253
+ return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
254
+
255
+
256
+ def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
257
+ device_op_overrides_dict[device] = device_op_overrides
258
+
259
+
260
+ def get_device_op_overrides(device: str):
261
+ assert isinstance(device, str)
262
+
263
+ if not device_op_overrides_dict.keys():
264
+ from .cuda import device_op_overrides # noqa: F401
265
+ from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
266
+
267
+ if device in device_op_overrides_dict.keys():
268
+ return device_op_overrides_dict[device]
269
+
270
+
271
+ @functools.lru_cache(None)
272
+ def boolean_ops():
273
+ return (
274
+ "isinf",
275
+ "isnan",
276
+ "logical_not",
277
+ "signbit",
278
+ "le",
279
+ "lt",
280
+ "ge",
281
+ "gt",
282
+ "eq",
283
+ "ne",
284
+ )
285
+
286
+
287
+ DTYPE_TO_COMPUTATION_DTYPE = {
288
+ torch.bfloat16: torch.float,
289
+ torch.float16: torch.float,
290
+ **{
291
+ dtype: dtype
292
+ for dtype in [
293
+ torch.bool,
294
+ torch.float32,
295
+ torch.float64,
296
+ torch.int8,
297
+ torch.int16,
298
+ torch.int32,
299
+ torch.int64,
300
+ torch.uint8,
301
+ torch.uint16,
302
+ torch.uint32,
303
+ torch.uint64,
304
+ ]
305
+ },
306
+ }
307
+
308
+
309
+ def deduce_output_dtype_by_name(
310
+ op_name: str,
311
+ *args,
312
+ **kwargs,
313
+ ) -> Optional[torch.dtype]:
314
+ """
315
+ Given op name and a list of input dtypes, deduce the output dtype
316
+ """
317
+ if op_name in boolean_ops():
318
+ return torch.bool
319
+ elif op_name in (
320
+ "to_dtype",
321
+ "index_expr",
322
+ ):
323
+ return kwargs["dtype"] if "dtype" in kwargs else args[-1]
324
+ elif op_name in (
325
+ "rand",
326
+ "randn",
327
+ ):
328
+ return torch.float
329
+ elif op_name in (
330
+ "get_index",
331
+ "randint64",
332
+ "load_seed",
333
+ ):
334
+ return torch.int64
335
+ elif op_name == "reduction":
336
+ return kwargs["dtype"] if "dtype" in kwargs else args[1]
337
+ elif op_name == "constant":
338
+ dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1]
339
+ return DTYPE_TO_COMPUTATION_DTYPE[dtype] # type: ignore[index]
340
+ elif op_name in (
341
+ "load",
342
+ "store",
343
+ "store_reduction",
344
+ ):
345
+ buf_name = args[1]
346
+ return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
347
+ elif op_name == "to_dtype_bitcast":
348
+ return kwargs["dtype"] if "dtype" in kwargs else args[-2]
349
+ return None
350
+
351
+
352
class DataTypePropagation:
    """Propagate dtypes through the FX graphs of an inductor LoopBody.

    Walks the root graph plus every masked-subblock graph, annotating each
    FX node's ``meta`` with an ``OptimizationContext`` whose ``dtype`` field
    is deduced from the op name, the subgraph's output, or promotion of the
    input dtypes.
    """

    def __init__(self, body) -> None:
        # `body` is a LoopBody (see propagate_scheduler_node); collect its
        # root graph and all subblock graphs keyed by subblock name.
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Deduce `node`'s dtype by promoting the dtypes of its inputs.

        Returns None when the node has no non-placeholder inputs, or when
        any input has not yet been annotated with a dtype.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        # Standard elementwise type promotion across all input dtypes.
        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Deduce `node`'s dtype from the output dtype of its subgraph."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Deduce the dtype of a single FX node; returns None when unknown."""
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can only infer the output node's dtype when it has exactly 1 arg
            return None

        if node.target == operator.getitem:
            # getitem forwards the dtype of the value it indexes into
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        # Ops with a statically-known output dtype (loads, stores, bitcasts, ...)
        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate every node of `graph` with its deduced dtype.

        Returns the dtype of the graph's output node (meaningful for
        masked_subblock graphs); may be None otherwise.
        """
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        # Entry point: propagate over the root graph (subblocks are visited
        # on demand via deduce_node_dtype_by_subgraph).
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Convenience wrapper: run propagation over a LoopBody."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run propagation over a SchedulerNode's LoopBody."""
        # Imported locally to avoid circular imports at module load time.
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)
448
+
449
+
450
+ # This printer contains rules that are supposed to be generic for both C/C++ and
451
+ # Python
452
class ExprPrinter(Printer):
    """Base sympy expression printer with rules shared by C/C++ and Python.

    Target-specific constructs are deliberately left as NotImplementedError
    stubs here so a missing override fails loudly and names the printer
    class that needs the implementation.
    """

    @staticmethod
    def paren(string):
        """Parenthesize `string` unless it is atomic or already fully wrapped."""

        def all_in_parens(string):
            # True iff a single balanced pair of parens wraps the whole string.
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    # outermost paren closed before the end -> not fully wrapped
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        # CleanDiv is an exact (remainder-free) division; print as FloorDiv.
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        return self._print(expr.args[0])

    def _print_GreaterThan(self, expr):
        # GreaterThan: >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this pow by natural, you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  These
    # means exp is guaranteed to be integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            # expand to repeated multiplication, e.g. x*x*x for exp == 3
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        """Print `expr`, optionally simplifying via the graph's sizevars first."""
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
594
+
595
+
596
class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as executable Python source (math.* based).

    Used for host-side size/index computations; several methods carry
    warnings about Triton, whose C-style semantics differ from Python's.
    """

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        # ModularIndexing(x, div, mod) == (x // div) % mod
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        # Consistency fix: assert unary arity like every other unary printer
        # in this class, instead of silently ignoring extra args.
        assert len(expr.args) == 1
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
729
+
730
+
731
class OpOverrides:
    """Default op implementations expressed in terms of other ops.

    Backends subclass this and override what they can lower natively; any
    op not defined here is forwarded to the wrapped handler via
    ``__getattr__``.
    """

    def __init__(self, parent):
        super().__init__()
        # The underlying ops handler we decorate.
        self._parent = parent

    def __getattr__(self, item):
        # Fall through to the wrapped handler for ops not overridden here.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        # erfc(x) = 1 - erf(x)
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        # erfcx(x) = exp(x^2) * erfc(x)
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        # change of base: log10(x) = log(x) / log(10)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        # change of base: log2(x) = log(x) / log(2)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        # 2**x = exp(x * ln 2)
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        # sigmoid(x) = 1 / (1 + exp(-x))
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        # Same decomposition as sigmoid() but routed through libdevice_exp.
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # When the remainder is nonzero and its sign differs from the
        # divisor's, add the divisor so the result follows the divisor's
        # sign (Python-style remainder built on a C-style mod).
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the per-backend lambdas from pointwise_overrides_data onto cls."""
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                # no implementation for this backend; leave the default
                continue
            setattr(cls, funcname, staticmethod(impl))
895
+
896
+
897
@dataclasses.dataclass
class OverridesData:
    """Per-op record of backend printer lambdas for a pointwise special function."""

    # Name used on the ops handler (e.g. "special_bessel_j0").
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
908
+
909
+
910
+ # NB: if you add a new special function, don't forget to update
911
+ # torch._inductor.ops_handler too
912
# Table of special functions: key is the op name used by
# _initialize_pointwise_overrides; each OverridesData supplies the code
# emitters for the cpp / triton / cppvec backends (None = not available).
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        # NB: the argument order is swapped for the C++ helper
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)
+ )
1136
+
1137
+
1138
+ # Use mypy to check protocol implemented correctly
1139
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    # Static-only check: mypy verifies OpOverrides satisfies OpsHandler[str].
    return h
1141
+
1142
+
1143
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        # Buffer whose liveness decides whether this line is emitted.
        self.name = name
        # Nesting deferred lines would make the removal check ambiguous.
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        # Emit only if the buffer is still live in both the graph and the
        # kernel (neither removed nor scheduled for inplace removal).
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        # Preserve the deferred-on buffer name when rewriting the line text.
        return DeferredLine(self.name, line)
1166
+
1167
+
1168
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer that emits C-style braces around indented regions.

    ``indent(offset)`` opens `offset` brace scopes (or, for a negative
    offset, closes ``-offset`` scopes), and undoes that on context exit.
    """

    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            # Positive offset: open scopes.
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            # Negative offset: close scopes instead.
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            # On exit, reverse whichever direction we went.
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()
1187
+
1188
+
1189
class InplacedBuffer(NamedTuple):
    """A single in/out kernel pointer argument aliased by several buffer names."""

    # Kernel-side argument name, e.g. "in_out_ptr0".
    inner_name: str
    # All outer buffer names that map onto this argument.
    other_names: List[str]
1192
+
1193
+
1194
+ class KernelArgs:
1195
+ @staticmethod
1196
+ def _lookup(prefix, odict, name):
1197
+ assert isinstance(name, (str, sympy.Symbol))
1198
+ if name not in odict:
1199
+ odict[name] = f"{prefix}{len(odict)}"
1200
+ return odict[name]
1201
+
1202
    def __init__(self, sizevars=None):
        # Maps of outer (graph) names -> inner kernel argument names.
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        # Size variables; an initial mapping may be supplied by the caller.
        self.sizevars = sizevars or {}
        # Single shared scratch allocation, grown by workspace().
        self.workspace_arg = None
1208
+
1209
+ def __repr__(self):
1210
+ return "KernelArgs({})".format(
1211
+ ", ".join(
1212
+ map(
1213
+ repr,
1214
+ [
1215
+ self.input_buffers,
1216
+ self.output_buffers,
1217
+ self.inplace_buffers,
1218
+ self.sizevars,
1219
+ ],
1220
+ )
1221
+ )
1222
+ )
1223
+
1224
+ def _buffer_is_marked_removed(self, name):
1225
+ return isinstance(name, str) and name.startswith("REMOVED")
1226
+
1227
    def input(self, name):
        """Register `name` as a kernel input; return its inner argument name.

        Resolves mutation aliases via the scheduler and reuses any existing
        output or inplace binding for the same buffer.
        """
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            # RNG seeds get their own prefix so they are easy to identify.
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)
1238
+
1239
    def output(self, name):
        """Register `name` as a kernel output; return its inner argument name.

        Resolves mutation aliases via the scheduler and reuses any existing
        inplace binding for the same buffer.
        """
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)
1246
+
1247
    def make_inplace(self, input_name, output_name):
        """Alias `output_name` to `input_name` as one in/out pointer argument."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            # The input already belongs to an alias group; extend it.
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            # Start a new alias group; count distinct groups (the same
            # InplacedBuffer appears under several keys) to number it.
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf
1260
+
1261
    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve `nbytes` of scratch space; return ("ws_ptr", byte offset).

        Successive calls grow a single shared workspace allocation; the
        zero_fill flag is sticky once any caller has requested it.
        """
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset
1270
+
1271
+ def seed_offset(self, name, value):
1272
+ if value in self.sizevars:
1273
+ return self.sizevars[value]
1274
+ if name in self.sizevars.values():
1275
+ name = (
1276
+ f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
1277
+ )
1278
+ self.sizevars[value] = name
1279
+ return name
1280
+
1281
    def size(self, name):
        """Map a size variable to its inner kernel argument name ("ks<N>")."""
        if str(name) == "seed":
            # The seed passes through under its own name rather than a ks slot.
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)
1286
+
1287
    def call_names(self):
        """Iterate the outer (caller-side) names of all buffers and size vars."""
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )
1291
+
1292
    def wrap_ptr_arg(self, buf, dtype):
        # Hook for backends to wrap a pointer call argument; identity here.
        return buf
1294
+
1295
    def wrap_size_arg(self, size):
        # Hook for backends to wrap a size call argument; stringified here.
        return str(size)
1297
+
1298
    def cpp_argdefs(self):
        """Build the C++ kernel signature.

        Returns (arg_defs, call_args, arg_types): parameter declarations,
        the matching caller-side expressions, and the bare C++ types, in
        the order inplace buffers, inputs, outputs, then size vars.
        """
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            # Pass the most recent alias of the in/out buffer at the call site.
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            # Inputs are read-only, hence const pointers.
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                # Make sure the wrapper has emitted the size computation.
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types
1338
+
1339
+ def python_argdefs(self):
1340
+ arg_defs: List[str] = []
1341
+ call_args: List[str] = []
1342
+ arg_types: List[torch.dtype] = []
1343
+ precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
1344
+ for inplaced in unique(self.inplace_buffers.values()):
1345
+ if self._buffer_is_marked_removed(inplaced):
1346
+ continue
1347
+ arg_defs.append(inplaced.inner_name)
1348
+ call_args.append(inplaced.other_names[-1])
1349
+ arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
1350
+ precompile_args.append(
1351
+ TensorArg(
1352
+ name=inplaced.inner_name,
1353
+ buffer=inplaced.other_names[-1],
1354
+ dtype=V.graph.get_dtype(inplaced.other_names[-1]),
1355
+ )
1356
+ )
1357
+ for outer, inner in chain(
1358
+ self.input_buffers.items(), self.output_buffers.items()
1359
+ ):
1360
+ if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1361
+ continue
1362
+ arg_defs.append(inner)
1363
+ call_args.append(outer)
1364
+ arg_types.append(V.graph.get_dtype(outer))
1365
+ precompile_args.append(
1366
+ TensorArg(
1367
+ name=inner,
1368
+ buffer=outer,
1369
+ dtype=V.graph.get_dtype(outer),
1370
+ )
1371
+ )
1372
+ for outer, inner in self.sizevars.items():
1373
+ arg_defs.append(inner)
1374
+ call_args.append(outer)
1375
+ arg_types.append(type(outer)) # type: ignore[arg-type]
1376
+ precompile_args.append(SizeArg(inner, outer))
1377
+ if V.graph.wrapper_code:
1378
+ V.graph.wrapper_code.ensure_size_computed(outer)
1379
+ if self.workspace_arg is not None:
1380
+ arg_defs.append("ws_ptr")
1381
+ call_args.append("workspace")
1382
+ precompile_args.append(self.workspace_arg)
1383
+ return arg_defs, call_args, precompile_args, arg_types
1384
+
1385
+ def aliases(self):
1386
+ for inplaced in unique(self.inplace_buffers.values()):
1387
+ if self._buffer_is_marked_removed(inplaced):
1388
+ continue
1389
+ for other in inplaced.other_names:
1390
+ if (
1391
+ other in V.graph.inplaced_to_remove
1392
+ or other in V.kernel.inplaced_to_remove
1393
+ ):
1394
+ continue
1395
+ if other in self.input_buffers:
1396
+ yield self.input_buffers[other], inplaced.inner_name
1397
+ if other in self.output_buffers:
1398
+ yield self.output_buffers[other], inplaced.inner_name
1399
+
1400
+ def is_removed(self, name):
1401
+ def _is_removed(name, buffers):
1402
+ return name not in buffers or self._buffer_is_marked_removed(buffers[name])
1403
+
1404
+ return _is_removed(name, self.output_buffers) and _is_removed(
1405
+ name, self.inplace_buffers
1406
+ )
1407
+
1408
+ # Includes inplace buffers, excludes removed buffers. Essentially,
1409
+ # after you do a call into this kernel, which buffers actually contain
1410
+ # updated data? Modeled off of python_argdefs.
1411
+ def live_output_buffers(self):
1412
+ live_outs = OrderedSet() # type: ignore[var-annotated]
1413
+ for inplaced in unique(self.inplace_buffers.values()):
1414
+ if self._buffer_is_marked_removed(inplaced):
1415
+ continue
1416
+ live_outs.add(inplaced.other_names[-1])
1417
+ for outer, inner in self.output_buffers.items():
1418
+ if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1419
+ continue
1420
+ live_outs.add(outer)
1421
+ return live_outs
1422
+
1423
+
1424
+ class CSEVariable:
1425
+ """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
1426
+ To do so, the backends can simply overload `Kernel.create_cse_var`
1427
+ The "CSEVariable.update_on_args" method gives you a hook for annotations
1428
+ See example of TritonCSEVariable in triton.py
1429
+ """
1430
+
1431
+ def __init__(self, name, bounds: ValueRanges[Any]):
1432
+ assert isinstance(bounds, ValueRanges)
1433
+ self.name = name
1434
+ self.bounds = bounds
1435
+ self.use_count = 1 # track how many tims this expression is used
1436
+
1437
+ def __str__(self):
1438
+ return self.name
1439
+
1440
+ def __hash__(self) -> int:
1441
+ return hash(self.name)
1442
+
1443
+ def __eq__(self, other) -> bool:
1444
+ return type(other) == type(self) and other.name == self.name
1445
+
1446
+ def update_on_args(self, name, args, kwargs):
1447
+ pass
1448
+
1449
+ def __repr__(self):
1450
+ return f"{self.__class__.__name__}({self.name!r})"
1451
+
1452
+
1453
+ class CppWrapperKernelArgs(KernelArgs):
1454
+ def wrap_ptr_arg(self, buf, dtype):
1455
+ from .cpp_utils import DTYPE_TO_CPP
1456
+
1457
+ if config.abi_compatible:
1458
+ # In the abi_compatible model, we just return the buf here.
1459
+ # We will form correct call args later in wrapper.generate_kernel_all.
1460
+ return buf
1461
+ else:
1462
+ return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
1463
+
1464
+ def wrap_size_arg(self, size):
1465
+ return f"{size}"
1466
+
1467
+
1468
+ class CSE:
1469
+ """Common subexpression elimination"""
1470
+
1471
+ def __init__(
1472
+ self,
1473
+ prefix="",
1474
+ suffix="",
1475
+ name_prefix="tmp",
1476
+ iter_buffers=None,
1477
+ store_cache=None,
1478
+ reduction_cache=None,
1479
+ varname_map=None,
1480
+ ):
1481
+ self.prefix = prefix
1482
+ self.suffix = suffix
1483
+ self.cache = {}
1484
+ self.name_prefix = name_prefix
1485
+ self.store_cache = store_cache or {}
1486
+ self.reduction_cache = reduction_cache or {}
1487
+ self.iter_buffer_ids = iter_buffers or itertools.count()
1488
+ self.invalidated_stores = OrderedSet() # type: ignore[var-annotated]
1489
+ self.varname_map = varname_map or {}
1490
+
1491
+ def invalidate(self, keep_vars: OrderedSet[str]):
1492
+ for name, tmp in list(self.store_cache.items()):
1493
+ if tmp not in keep_vars:
1494
+ del self.store_cache[name]
1495
+ self.invalidated_stores.add(name)
1496
+ self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
1497
+
1498
+ def clone(self):
1499
+ # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
1500
+ return CSE(
1501
+ prefix=self.prefix,
1502
+ suffix=self.suffix,
1503
+ name_prefix=self.name_prefix,
1504
+ iter_buffers=self.iter_buffer_ids,
1505
+ store_cache=self.store_cache,
1506
+ varname_map=self.varname_map,
1507
+ )
1508
+
1509
+ def generate(
1510
+ self,
1511
+ buffer: IndentedBuffer,
1512
+ expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
1513
+ *,
1514
+ bounds: ValueRanges[Any] = ValueRanges.unknown(),
1515
+ write=True,
1516
+ assignment=True,
1517
+ ) -> CSEVariable:
1518
+ if isinstance(expr, OpsValue):
1519
+ expr = expr.value
1520
+
1521
+ assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
1522
+ assert write or assignment
1523
+ if isinstance(expr, CSEVariable):
1524
+ # If the expressions were always created with all the information, we could
1525
+ # assert expr.bounds == bounds, but sometimes the expression is created
1526
+ # with the loose ValueRanges.unknown(), so we need to tighten the bounds
1527
+ expr.bounds = expr.bounds.tighten(bounds)
1528
+ expr.use_count += 1
1529
+ return expr
1530
+ cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
1531
+ var = self.cache.get(cache_key, None)
1532
+ if not var:
1533
+ var = self.newvar(bounds)
1534
+ self.cache[cache_key] = var
1535
+ if write:
1536
+ if V.kernel.current_node:
1537
+ V.kernel.current_node.codegen_originating_info(
1538
+ buffer, only_once=True
1539
+ )
1540
+ if isinstance(expr, IndentedBuffer):
1541
+ if assignment:
1542
+ buffer.writeline(f"{self.prefix}{var} =")
1543
+ buffer.splice(expr)
1544
+ buffer.writeline(self.suffix)
1545
+ else:
1546
+ if assignment:
1547
+ line = f"{self.prefix}{var} = {expr}{self.suffix}"
1548
+ else:
1549
+ line = f"{expr}{self.suffix}"
1550
+ buffer.writeline(line)
1551
+ else:
1552
+ var.bounds = var.bounds.tighten(bounds)
1553
+ var.use_count += 1
1554
+
1555
+ return var
1556
+
1557
+ def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
1558
+ var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
1559
+ var = V.kernel.create_cse_var(var_name, bounds)
1560
+ self.varname_map[var_name] = var
1561
+ return var
1562
+
1563
+
1564
+ class CodeGen:
1565
+ def __init__(self) -> None:
1566
+ super().__init__()
1567
+ self.exit_stack = contextlib.ExitStack()
1568
+
1569
+ def __enter__(self):
1570
+ self.exit_stack.__enter__()
1571
+ return self
1572
+
1573
+ def __exit__(self, exc_type, exc_val, exc_tb):
1574
+ self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
1575
+
1576
+
1577
+ class ScopedDict:
1578
+ def __init__(self, original_dict):
1579
+ self.original_dict = original_dict
1580
+ self.new_items = {}
1581
+
1582
+ def __getitem__(self, key):
1583
+ if key in self.new_items:
1584
+ return self.new_items[key]
1585
+ return self.original_dict[key]
1586
+
1587
+ def __setitem__(self, key, value):
1588
+ self.new_items[key] = value
1589
+
1590
+ def __contains__(self, key):
1591
+ return key in self.new_items or key in self.original_dict
1592
+
1593
+ def get(self, key, default=None):
1594
+ if key in self.new_items:
1595
+ return self.new_items[key]
1596
+ return self.original_dict.get(key, default)
1597
+
1598
+
1599
+ class Kernel(CodeGen):
1600
+ newvar_prefix = ""
1601
+ suffix = ""
1602
+ overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
1603
+ # TODO: these look dead, but with all the getattr it's hard to tell...
1604
+ load_format: None = None
1605
+ store_format: None = None
1606
+
1607
+ def __init__(self, args=None, increase_kernel_count=True):
1608
+ super().__init__()
1609
+ if increase_kernel_count:
1610
+ metrics.generated_kernel_count += 1
1611
+ self.args = args or KernelArgs()
1612
+ self.loads = IndentedBuffer()
1613
+ self.compute = IndentedBuffer()
1614
+ self.stores = IndentedBuffer()
1615
+
1616
+ self.num_load = 0
1617
+ self.num_reduction = 0
1618
+
1619
+ self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
1620
+ self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated]
1621
+ self.store_buffer_names = OrderedSet() # type: ignore[var-annotated]
1622
+ self._load_mask = None
1623
+ self._load_other = None
1624
+ # OrderedSet in set_current_node
1625
+ self.current_node = None
1626
+ self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
1627
+
1628
+ self.removed_buffers = OrderedSet() # type: ignore[var-annotated]
1629
+ self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated]
1630
+
1631
+ # key: the buffer to write
1632
+ # value: the buffer to read and whose memory can be reused for
1633
+ # the buffer specified by key
1634
+ self.inplace_update_buffers = {}
1635
+ # Set minimum number of elements processed per thread.
1636
+ self.min_elem_per_thread = 1
1637
+ self.kernel_name = None
1638
+
1639
+ @contextlib.contextmanager
1640
+ def set_current_node(self, node):
1641
+ prior = self.current_node
1642
+ self.current_node = node
1643
+ self.node_to_bounds = node._body.bounds().get_bounds()
1644
+ try:
1645
+ yield
1646
+ finally:
1647
+ self.current_node = prior
1648
+
1649
+ @contextlib.contextmanager
1650
+ def swap_buffers(self, lb, cb=None, sb=None):
1651
+ def scope_cse(cse):
1652
+ new_cse = cse.clone()
1653
+ new_cse.cache = ScopedDict(cse.cache)
1654
+ new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
1655
+ new_cse.store_cache = ScopedDict(cse.store_cache)
1656
+ return new_cse
1657
+
1658
+ if cb is None:
1659
+ cb = lb
1660
+ loads = self.loads
1661
+ compute = self.compute
1662
+ stores = self.stores
1663
+ cse = self.cse
1664
+ self.loads = lb
1665
+ self.compute = cb
1666
+ self.stores = sb
1667
+ self.cse = scope_cse(cse)
1668
+ try:
1669
+ yield
1670
+ finally:
1671
+ self.loads = loads
1672
+ self.compute = compute
1673
+ self.stores = stores
1674
+ self.cse = cse
1675
+
1676
+ def load(self, name: str, index: sympy.Expr) -> CSEVariable:
1677
+ raise NotImplementedError
1678
+
1679
+ def indirect_load(self, name: str, index: sympy.Expr):
1680
+ """A load the depends on an index we have read"""
1681
+ prior = self.loads
1682
+ try:
1683
+ # put the load in the compute section as it might have deps
1684
+ self.loads = self.compute
1685
+ return self.load(name, index)
1686
+ finally:
1687
+ self.loads = prior
1688
+
1689
+ def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
1690
+ raise NotImplementedError
1691
+
1692
+ def store(
1693
+ self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1694
+ ) -> None:
1695
+ raise NotImplementedError
1696
+
1697
+ def reduction(
1698
+ self,
1699
+ dtype: torch.dtype,
1700
+ src_dtype: torch.dtype,
1701
+ reduction_type: ReductionType,
1702
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1703
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1704
+ raise NotImplementedError
1705
+
1706
+ def scan(
1707
+ self,
1708
+ dtypes: Tuple[torch.dtype, ...],
1709
+ combine_fn: Callable[
1710
+ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
1711
+ ],
1712
+ values: Tuple[CSEVariable, ...],
1713
+ ) -> Tuple[CSEVariable, ...]:
1714
+ raise NotImplementedError
1715
+
1716
+ def sort(
1717
+ self,
1718
+ dtypes: Tuple[torch.dtype, ...],
1719
+ values: Tuple[CSEVariable, ...],
1720
+ stable: bool,
1721
+ descending: bool,
1722
+ ) -> Tuple[CSEVariable, ...]:
1723
+ raise NotImplementedError
1724
+
1725
+ def var_ranges(self):
1726
+ raise NotImplementedError
1727
+
1728
+ def bucketize(
1729
+ self,
1730
+ values: CSEVariable,
1731
+ offsets_name: str,
1732
+ offsets_size: sympy.Expr,
1733
+ indexing_dtype: torch.dtype,
1734
+ right: bool,
1735
+ ) -> CSEVariable:
1736
+ """
1737
+ See [Note: Inductor bucketize op]
1738
+ """
1739
+ raise NotImplementedError
1740
+
1741
+ @property
1742
+ def assert_function(self) -> str:
1743
+ raise NotImplementedError
1744
+
1745
+ def indirect_assert(
1746
+ self,
1747
+ var: Union[CSEVariable, str],
1748
+ lower: Optional[str],
1749
+ upper: Optional[str],
1750
+ mask: Optional[Union[CSEVariable, str]] = None,
1751
+ ) -> str:
1752
+ if isinstance(var, CSEVariable):
1753
+ var = str(var)
1754
+ assert isinstance(var, str)
1755
+ assert lower is None or isinstance(lower, str)
1756
+ assert upper is None or isinstance(upper, str)
1757
+ if lower and upper:
1758
+ # The conditions need to be in parens because of Python's operator precedence.
1759
+ # It'd be less error-prone to use and/or/not, which is suported by triton
1760
+ cond = f"({lower} <= {var}) & ({var} < {upper})"
1761
+ cond_print = f"{lower} <= {var} < {upper}"
1762
+ elif lower:
1763
+ cond = f"{lower} <= {var}"
1764
+ cond_print = cond
1765
+ else:
1766
+ assert upper
1767
+ cond = f"{var} < {upper}"
1768
+ cond_print = cond
1769
+
1770
+ if mask:
1771
+ cond = f"({cond}) | ~({mask})"
1772
+
1773
+ return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
1774
+
1775
+ def check_bounds(
1776
+ self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
1777
+ ):
1778
+ raise NotImplementedError
1779
+
1780
+ def index_to_str(self, index: sympy.Expr) -> str:
1781
+ raise NotImplementedError
1782
+
1783
+ def __enter__(self):
1784
+ # TODO: hoist this to top level
1785
+ class CSEProxy:
1786
+ self.name = "CSEProxy"
1787
+ vr_analysis = ValueRangeAnalysis()
1788
+
1789
+ @staticmethod
1790
+ def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
1791
+ def inner(*args, **kwargs):
1792
+ bounds = CSEProxy._bound_variable(name, *args, **kwargs)
1793
+
1794
+ value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
1795
+
1796
+ def do_cse(v):
1797
+ csevar = V.kernel.cse.generate(
1798
+ V.kernel.compute, v, bounds=bounds
1799
+ )
1800
+ csevar.update_on_args(name, args, kwargs)
1801
+ return csevar
1802
+
1803
+ return pytree.tree_map(do_cse, value)
1804
+
1805
+ return inner
1806
+
1807
+ @staticmethod
1808
+ def _bound_variable(name, *args, **kwargs):
1809
+ """
1810
+ If the variable comes from an FX node, we forward the bound we have already computed
1811
+ Else, if the variable when codegen'ing another op, we try to compute its bounds
1812
+ """
1813
+ from ..select_algorithm import TritonTemplateKernel
1814
+
1815
+ if isinstance(V.kernel, TritonTemplateKernel):
1816
+ return ValueRanges.unknown()
1817
+
1818
+ fx_node = V.interpreter.current_node
1819
+ if fx_node.target == name and self.node_to_bounds is not None:
1820
+ assert isinstance(self.node_to_bounds, dict)
1821
+ return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
1822
+ elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
1823
+ # These create lots of inner strings. We would need to compute the bounds at the ops
1824
+ # We will also likely not get much from computing VRs on these nodes
1825
+ if any(
1826
+ s in fx_node.target
1827
+ for s in ("set_indirect", "reduction", "scan")
1828
+ ):
1829
+ return ValueRanges.unknown()
1830
+
1831
+ # We assume that the inputs come from `ops.` and are not strings. If you want to generate
1832
+ # intermediary strings, wrap them in CSE variables with properly initialised bounds.
1833
+
1834
+ # If there is no FX bound but we know how to compute one we do so
1835
+ assert not kwargs
1836
+
1837
+ def arg_to_bound(x):
1838
+ if isinstance(x, CSEVariable):
1839
+ return x.bounds
1840
+ elif isinstance(x, sympy.Expr):
1841
+ return bound_sympy(x)
1842
+ else:
1843
+ return x
1844
+
1845
+ arg_bounds = list(map(arg_to_bound, args))
1846
+ return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
1847
+ else:
1848
+ return ValueRanges.unknown()
1849
+
1850
+ @staticmethod
1851
+ def indirect_indexing(
1852
+ var: CSEVariable,
1853
+ size: Union[sympy.Expr, int],
1854
+ check: bool = True,
1855
+ wrap_neg=True,
1856
+ ):
1857
+ if isinstance(size, int):
1858
+ size = sympy.Integer(size)
1859
+ assert isinstance(size, sympy.Expr), size
1860
+ # Skip CSE since this doesn't return an expression
1861
+
1862
+ if var.bounds.lower < 0: # type: ignore[operator]
1863
+ if wrap_neg:
1864
+ stm = ops.add(var, ops.index_expr(size, torch.long))
1865
+ # Mixed negative and non-negative
1866
+ if var.bounds.upper >= 0: # type: ignore[operator]
1867
+ lt = ops.lt(var, 0)
1868
+ stm = ops.where(lt, stm, var)
1869
+ else:
1870
+ stm = var
1871
+
1872
+ # Propagate bounds as we know how to compute them properly
1873
+ new_bounds = ValueRanges.unknown()
1874
+ if var.bounds != ValueRanges.unknown() and isinstance(
1875
+ size, sympy.Number
1876
+ ):
1877
+ # Take the negative part of the bound and add size to it
1878
+ # Then take union of that and the positive part
1879
+ # This is a tighter bound than that of a generic ops.where, as we have info on the cond
1880
+ neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
1881
+ new_bounds = ValueRanges(
1882
+ neg_bounds.lower + size, neg_bounds.upper + size
1883
+ )
1884
+ # We don't have a good way of representing the empty range
1885
+ if var.bounds.upper >= 0: # type: ignore[operator]
1886
+ pos = var.bounds & ValueRanges(0, int_oo)
1887
+ new_bounds = new_bounds | pos
1888
+
1889
+ var = self.cse.generate(self.compute, stm, bounds=new_bounds)
1890
+
1891
+ sympy_var = parent_handler.indirect_indexing(var, size, check)
1892
+ if generate_assert(check):
1893
+ assert_lower = not (var.bounds.lower >= 0)
1894
+ # value ranges cannot x < s when x and s are symbols
1895
+ assert_upper = not isinstance(size, sympy.Number) or not (
1896
+ var.bounds.upper < size
1897
+ )
1898
+ self.check_bounds(sympy_var, size, assert_lower, assert_upper)
1899
+ return sympy_var
1900
+
1901
+ @staticmethod
1902
+ def check_bounds(
1903
+ expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
1904
+ ):
1905
+ return self.check_bounds(expr, size, lower, upper)
1906
+
1907
+ @staticmethod
1908
+ def load(name: str, index: sympy.Expr) -> CSEVariable:
1909
+ if name in self.cse.invalidated_stores:
1910
+ # A load from an invalidated store requires us to
1911
+ # keep the actual buffer around
1912
+ V.kernel.must_keep_buffers.add(name)
1913
+ if free_symbol_is_type(index, SymT.TMP):
1914
+ return self.indirect_load(name, index)
1915
+ store_cache = self.cse.store_cache
1916
+ if name in store_cache:
1917
+ return store_cache[name]
1918
+ out = self.load(name, index)
1919
+ # count load that is not in the store_cache, and also not in the
1920
+ # cse cache.
1921
+ if out.use_count == 1:
1922
+ self.num_load += 1
1923
+ return out
1924
+
1925
+ @staticmethod
1926
+ def _update_store_cache(name: str, value: CSEVariable):
1927
+ self.cse.store_cache[name] = value
1928
+ if self.current_node and name in V.graph.name_to_buffer:
1929
+ buf = self.current_node.get_output(name)
1930
+ for other_name in buf.get_mutations():
1931
+ self.cse.store_cache[other_name] = value
1932
+
1933
+ @staticmethod
1934
+ def store(
1935
+ name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1936
+ ) -> None:
1937
+ self.store_buffer_names.add(name)
1938
+ if mode is None:
1939
+ CSEProxy._update_store_cache(name, value)
1940
+ if name not in V.graph.removed_buffers:
1941
+ return self.store(name, index, value, mode=mode)
1942
+ else:
1943
+ return None # type: ignore[return-value]
1944
+
1945
+ @staticmethod
1946
+ def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
1947
+ self.store_buffer_names.add(name)
1948
+ CSEProxy._update_store_cache(name, value)
1949
+
1950
+ if name not in V.graph.removed_buffers:
1951
+ return self.store_reduction(name, index, value)
1952
+
1953
+ @staticmethod
1954
+ def reduction(
1955
+ dtype: torch.dtype,
1956
+ src_dtype: torch.dtype,
1957
+ reduction_type: ReductionType,
1958
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1959
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1960
+ self.num_reduction += 1
1961
+ return self.reduction(dtype, src_dtype, reduction_type, value)
1962
+
1963
+ @staticmethod
1964
+ def scan(
1965
+ dtypes: Tuple[torch.dtype, ...],
1966
+ combine_fn: Callable[
1967
+ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
1968
+ Tuple[CSEVariable, ...],
1969
+ ],
1970
+ values: Tuple[CSEVariable, ...],
1971
+ ) -> Tuple[CSEVariable, ...]:
1972
+ return self.scan(dtypes, combine_fn, values)
1973
+
1974
+ @staticmethod
1975
+ def sort(
1976
+ dtypes: Tuple[torch.dtype, ...],
1977
+ values: Tuple[CSEVariable, ...],
1978
+ stable: bool,
1979
+ descending: bool,
1980
+ ) -> Tuple[CSEVariable, ...]:
1981
+ return self.sort(dtypes, values, stable, descending)
1982
+
1983
+ @staticmethod
1984
+ def bucketize(
1985
+ values: CSEVariable,
1986
+ offsets_name: str,
1987
+ offsets_size: sympy.Expr,
1988
+ indexing_dtype: torch.dtype,
1989
+ right: bool,
1990
+ ) -> CSEVariable:
1991
+ """
1992
+ [Note: Inductor bucketize op]
1993
+
1994
+ Given values (tensor) and offsets_name (reference to the name of a 1D
1995
+ tensor), calculate the bucket that each value belongs to.
1996
+
1997
+ e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
1998
+ return = [ 0, 1, 1, 1, 1, 3, 3, 4].
1999
+
2000
+ When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
2001
+ When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
2002
+
2003
+ Offsets must be non-decreasing or the result is undefined.
2004
+ """
2005
+ return self.bucketize(
2006
+ values, offsets_name, offsets_size, indexing_dtype, right
2007
+ )
2008
+
2009
+ # Use mypy to check protocol implemented correctly
2010
+ def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
2011
+ return h
2012
+
2013
+ super().__enter__()
2014
+ assert self.overrides
2015
+ parent_handler = self.overrides(V.get_ops_handler())
2016
+ self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
2017
+ self.exit_stack.enter_context(V.set_kernel_handler(self))
2018
+ return self
2019
+
2020
+ def __exit__(self, exc_type, exc_val, exc_tb):
2021
+ """
2022
+ Note that V.graph.scheduler can be None when codegening triton template
2023
+ kernels.
2024
+ """
2025
+ if V.graph.scheduler:
2026
+ V.graph.scheduler.remove_kernel_local_buffers()
2027
+ super().__exit__(exc_type, exc_val, exc_tb)
2028
+
2029
+ def rename_indexing(self, index) -> sympy.Expr:
2030
+ # adds the necessary kernel args for index expressions
2031
+ # and renames variables in index expressions to kernel arg names
2032
+ if isinstance(index, (list, tuple)):
2033
+ return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
2034
+ index = V.graph.sizevars.simplify(index)
2035
+ sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
2036
+ replacements = {
2037
+ x: self.args.size(x)
2038
+ for x in sorted_symbols
2039
+ if symbol_is_type(
2040
+ x,
2041
+ (
2042
+ SymT.UNBACKED_INT,
2043
+ SymT.SIZE,
2044
+ SymT.PRECOMPUTED_SIZE,
2045
+ ),
2046
+ )
2047
+ }
2048
+ return sympy_subs(index, replacements)
2049
+
2050
+ def create_cse_var(self, *args, **kwargs):
2051
+ return CSEVariable(*args, **kwargs)
2052
+
2053
+
2054
+ @dataclasses.dataclass
2055
+ class OptimizationContext:
2056
+ key: ClassVar[str] = "opt_ctx"
2057
+
2058
+ dtype: Optional[torch.dtype] = None
2059
+ ops_name: str = ""
2060
+
2061
+
2062
+ @functools.lru_cache(None)
2063
+ def jinja2_env():
2064
+ try:
2065
+ import jinja2
2066
+
2067
+ return jinja2.Environment(
2068
+ undefined=jinja2.StrictUndefined,
2069
+ )
2070
+ except ImportError:
2071
+ return None
2072
+
2073
+
2074
+ class KernelTemplate:
2075
+ """
2076
+ Base class for defining kernel templates.
2077
+
2078
+ Children classes: TritonTemplate, CUDATemplate
2079
+ """
2080
+
2081
+ @staticmethod
2082
+ def indent_except_first(source: str, num_indents: int, indents_spacing=4):
2083
+ lines = source.splitlines(True)
2084
+ if len(lines) > 1:
2085
+ lines[1:] = [
2086
+ (" " * indents_spacing * num_indents) + line for line in lines[1:]
2087
+ ]
2088
+ return "".join(lines)
2089
+
2090
+ @staticmethod
2091
+ def _template_from_string(source):
2092
+ env = jinja2_env()
2093
+ if env is not None:
2094
+ env.filters["indent_except_first"] = KernelTemplate.indent_except_first
2095
+ from jinja2 import TemplateSyntaxError
2096
+
2097
+ class DetailedTemplateSyntaxError(TemplateSyntaxError):
2098
+ def __init__(self, original_error):
2099
+ super().__init__(
2100
+ original_error.message,
2101
+ original_error.lineno,
2102
+ original_error.name,
2103
+ original_error.filename,
2104
+ )
2105
+ self.original_error = original_error
2106
+
2107
+ def __str__(self):
2108
+ error_info = f"Error in template at line {self.lineno}\n"
2109
+ error_info += f"Error message: {self.message}\n"
2110
+ if hasattr(self.original_error, "source"):
2111
+ lines = self.original_error.source.split("\n")
2112
+ error_info += "Context:\n"
2113
+ start = max(0, self.lineno - 2)
2114
+ end = min(len(lines), self.lineno + 2)
2115
+ for i in range(start, end):
2116
+ if i == self.lineno - 1:
2117
+ error_info += f"{i+1}: --> {lines[i]}\n"
2118
+ if hasattr(self.original_error, "column"):
2119
+ error_info += (
2120
+ " "
2121
+ + " " * (self.original_error.column - 1)
2122
+ + "^\n"
2123
+ )
2124
+ else:
2125
+ error_info += f"{i+1}: {lines[i]}\n"
2126
+ return error_info
2127
+
2128
+ try:
2129
+ return env.from_string(source)
2130
+ except TemplateSyntaxError as e:
2131
+ raise DetailedTemplateSyntaxError(e) from e
2132
+
2133
+ return None
2134
+
2135
+ @staticmethod
2136
+ def _fake_get_dtype(fake_out):
2137
+ _get_dtype_real = V.graph.get_dtype
2138
+
2139
+ def get_dtype(name):
2140
+ if name == fake_out.get_name():
2141
+ return fake_out.get_dtype()
2142
+ return _get_dtype_real(name)
2143
+
2144
+ return get_dtype
2145
+
2146
+ def __init__(self, name: str):
2147
+ self.name = name
2148
+
2149
+ def maybe_append_choice(self, choices, **kwargs):
2150
+ """
2151
+ Maybe generates a new ChoiceCaller and appends it into existing choices.
2152
+
2153
+ choices: A list of ChoiceCallers.
2154
+ kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
2155
+ """
2156
+
2157
+ try:
2158
+ choices.append(self.generate(**kwargs))
2159
+ except NotImplementedError as e:
2160
+ pass
2161
+
2162
+ def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
2163
+ """
2164
+ Generates a ChoiceCaller instance from the given arguments.
2165
+ """
2166
+
2167
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py ADDED
@@ -0,0 +1,1043 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import logging
4
+ import math
5
+ from functools import lru_cache
6
+ from typing import Any, Callable, cast, List, Optional, Set, Union
7
+ from unittest.mock import patch
8
+
9
+ import torch
10
+ import torch.utils
11
+
12
+ from ..._dynamo.utils import counters
13
+ from .. import config, ir, lowering as L
14
+ from ..kernel.mm_common import mm_args
15
+ from ..select_algorithm import DataProcessorTemplateWrapper
16
+ from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
17
+ from ..virtualized import ops, V
18
+ from .cpp import get_export_declaration
19
+ from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
20
+ from .cpp_template import CppTemplate
21
+ from .cpp_template_kernel import CppTemplateKernel
22
+ from .cpp_utils import (
23
+ create_epilogue_with_attr,
24
+ DTYPE_TO_CPP,
25
+ GemmBlocking,
26
+ get_gemm_template_output_and_compute_dtype,
27
+ )
28
+
29
+
30
+ log = logging.getLogger(__name__)
31
+
32
# Jinja template for the generated C++ packed-GEMM kernel.
# Rendered by CppPackedGemmTemplate.render(); the placeholders (kernel, micro_gemm,
# template, X/W/inp/Y, blocking sizes, etc.) are supplied there. The emitted C++
# performs a three-level blocked GEMM:
#   - thread blocking (Mt/Nt/Kt): partitions register blocks across OpenMP threads
#     (computed at compile time for static M, via mm_get_thread_blocking at runtime
#     for dynamic M);
#   - cache blocking (Mc/Nc/Kc): loop tiling inside each thread's partition;
#   - register blocking (Mr/Nr/Kr): handled by the micro_gemm kernel call.
# When `maybe_k_slicing` is set, partial accumulators from threads that split the
# K dimension are kept in per-thread local buffers and summed after an omp barrier.
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}

{{micro_gemm.codegen_define(kernel)}}

{%- if x_scale is not none %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
{%- else %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
{%- endif %}

extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
{
    {{kernel.maybe_codegen_profile()}}
    constexpr int64_t num_threads = {{num_threads}};
    constexpr int64_t N = {{N}};
    constexpr int64_t K = {{K}};
    constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
    constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
    constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;

{%- if is_dynamic_M %}
    const int64_t M = {{kernel.size(GemmOut, 0)}};
    const int64_t Mr_blocks = (M + Mr - 1) / Mr;
    {%- if num_threads > 1 %}
    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
    mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
    {%- else %}
    const auto Mt_blocks = Mr_blocks;
    const auto Nt_blocks = Nr_blocks;
    const auto Kt_blocks = Kr_blocks;
    {%- endif %}
    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
    uint32_t L1_cache_size = {{L1_cache_size}};
    uint32_t L2_cache_size = {{L2_cache_size}};
    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
        num_threads,
        M,
        N,
        K,
        Mr,
        Nr,
        Kr,
        Mt_blocks,
        Nt_blocks,
        Kt_blocks,
        Mc_blocks,
        Nc_blocks,
        Kc_blocks,
        L1_cache_size,
        L2_cache_size
    );
    const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
    constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}

    // make sure all partitions are assigned
    {{kernel.assert_function}}(
        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );

{%- if maybe_k_slicing %}
    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
    if (num_k_slices > 1) {
        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
    }
{%- endif %}

{%- if num_threads > 1 %}
    #pragma omp parallel num_threads({{num_threads}})
    {
        const int tid = omp_get_thread_num();
        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
        mm_get_thread_blocks(
            tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
        {%- if maybe_k_slicing %}
        const int64_t k_group_id = tid / num_k_slices;
        const int64_t k_slice_id = tid % num_k_slices;
        {%- endif %}
{%- else %}
    {
        const int tid = 0;
        const int64_t m_block_start = 0;
        const int64_t m_block_end = Mr_blocks;
        const int64_t n_block_start = 0;
        const int64_t n_block_end = Nr_blocks;
        const int64_t k_block_start = 0;
        const int64_t k_block_end = Kr_blocks;
{%- endif %}
        {{ micro_gemm.codegen_init(kernel) }}
{%- if use_local_acc %}
    {%- set acc_buf_name = "local_acc_buf" %}
        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
{%- endif %}
        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
{%- if use_local_acc %}
    {%- set acc = kernel.local_buffers[acc_buf_name] %}
                {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
{%- else %}
    {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- endif %}
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
    {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
    {%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
    {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
    {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
                        if (kc == k_block_start) {
                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
                        } else {
                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
                        }
                    }
                }
{%- if maybe_k_slicing %}
                if (num_k_slices > 1) {
                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
                    local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
                } else
{%- endif %}
                {
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
                    {{ kernel.store_output(
                        tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
                    )|indent(20, false)
                    }}
                }
            }
        }
{%- if maybe_k_slicing %}
        if (num_k_slices > 1) {
            #pragma omp barrier
            for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
                // We slice M-dim and each thread in the k-slicing group works on a slice
                const int64_t m_start_unsliced = mc * Mr;
                const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
                const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
                const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
                const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
                const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
                const int64_t m_size = m_end - m_start;
                const int64_t m_offset = m_start - m_start_unsliced;
                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                    const int64_t n_start = nc * Nr;
                    const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                    const int64_t n_size = n_end - n_start;
                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
                    auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
                    for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
                        auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
                        for (int64_t m = m_offset; m < m_offset + m_size; m++) {
                            #pragma omp simd
                            for (int64_t n = 0; n < n_size; n++) {
                                {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
                            }
                        }
                    }
    {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
                    {{ kernel.store_output(
                        tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
                    )|indent(20, false)
                    }}
                }
            }
        }
{%- endif %}
        {{ micro_gemm.codegen_finalize(kernel) }}
    }
}
"""
232
+
233
+
234
def get_padded_n(n, block_n):
    """Round ``n`` up to the nearest multiple of ``block_n``."""
    remainder = n % block_n
    return n if remainder == 0 else n + (block_n - remainder)
236
+
237
+
238
+ class CppPackedGemmTemplate(CppTemplate):
239
+ def __init__(
240
+ self,
241
+ input_nodes,
242
+ layout: ir.Layout,
243
+ num_threads: int,
244
+ register_blocking: GemmBlocking,
245
+ beta=1,
246
+ alpha=1,
247
+ has_bias=False,
248
+ epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
249
+ ) -> None:
250
+ assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
251
+ super().__init__(
252
+ "packed_gemm",
253
+ input_nodes,
254
+ layout,
255
+ num_threads,
256
+ epilogue_creator=epilogue_creator,
257
+ )
258
+ self.beta = beta
259
+ self.alpha = alpha
260
+ self.has_bias = has_bias
261
+ self.register_blocking = register_blocking
262
+ m, n = layout.size
263
+ _, k = input_nodes[0].get_size()
264
+ self.m, self.n, self.k = m, n, k
265
+ self.padded_n = get_padded_n(n, self.register_blocking.block_n)
266
+ self.is_dynamic_M = has_free_symbols((m,))
267
+
268
    @cache_on_self
    def thread_blocking(self) -> GemmBlocking:
        """
        NOTE [Thread blocking in Cpp GEMM]
        We use simple heuristics to decide the thread blocking:
        1. Make sure all threads are occupied as much as possible.
        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
        TODO(jgong5): allow tuning various blocking options

        Returns a GemmBlocking of per-thread block counts (in units of register
        blocks) along (m, n, k). Requires static M (asserts not is_dynamic_M).
        """

        @lru_cache(maxsize=100)
        def get_factors(number):
            # All divisors of `number`, emitted in pairs (number // i, i)
            # for i from floor(sqrt(number)) down to 1. Note the resulting
            # list is not sorted.
            factors = []
            for i in range(int(number**0.5), 0, -1):
                if number % i == 0:
                    factors.append(number // i)
                    factors.append(i)
            return factors

        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
            # Split each dimension's register-block count across its factor
            # of threads, rounding up so all blocks are covered.
            thread_block_k = math.ceil(k_blocks / k_factor)
            thread_block_n = math.ceil(n_blocks / n_factor)
            thread_block_m = math.ceil(m_blocks / m_factor)
            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

        assert (
            not self.is_dynamic_M
        ), "Unable to determine thread blocking for dynamic M."
        register_blocking = self.register_blocking
        # Total register-block counts along each dimension.
        m_blocks = math.ceil(self.m / register_blocking.block_m)
        n_blocks = math.ceil(self.n / register_blocking.block_n)
        k_blocks = math.ceil(self.k / register_blocking.block_k)
        factors = get_factors(self.num_threads)
        assert len(factors) > 0

        # User override: "m,n,k" thread factors from config, must multiply
        # to num_threads.
        if config.cpp.gemm_thread_factors is not None:
            factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
            assert len(factors) == 3
            assert math.prod(factors) == self.num_threads
            return get_blocking(
                factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
            )

        # we favor square-sized thread blocks for good data reuse
        def get_better_blocking(blocking, best_blocking):
            # Prefer larger block_k (less k-slicing); tie-break on smaller
            # block_m + block_n perimeter (closer to square).
            if best_blocking is None:
                best_blocking = blocking
            else:
                block_m_size = blocking.block_m * register_blocking.block_m
                block_n_size = blocking.block_n * register_blocking.block_n
                best_block_m_size = best_blocking.block_m * register_blocking.block_m
                best_block_n_size = best_blocking.block_n * register_blocking.block_n
                if blocking.block_k > best_blocking.block_k:
                    best_blocking = blocking
                elif (
                    blocking.block_k == best_blocking.block_k
                    and block_m_size + block_n_size
                    < best_block_m_size + best_block_n_size
                ):
                    best_blocking = blocking
            return best_blocking

        best_blocking = None
        # check if we can have a thread-blocking to occupy all threads without k-slicing
        for n_factor in factors:
            m_factor = self.num_threads // n_factor
            if n_blocks >= n_factor and m_blocks >= m_factor:
                blocking = get_blocking(
                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                )
                best_blocking = get_better_blocking(blocking, best_blocking)

        if best_blocking is None:
            # Fall back to k-slicing: split threads across k as well, bounded
            # by config.cpp.gemm_max_k_slices (0 means unbounded).
            for k_factor in factors:
                if k_blocks >= k_factor and (
                    config.cpp.gemm_max_k_slices == 0
                    or k_factor <= config.cpp.gemm_max_k_slices
                ):
                    n_factors = get_factors(self.num_threads // k_factor)
                    for n_factor in n_factors:
                        m_factor = (self.num_threads // k_factor) // n_factor
                        if n_blocks >= n_factor and m_blocks >= m_factor:
                            blocking = get_blocking(
                                m_factor,
                                n_factor,
                                k_factor,
                                m_blocks,
                                n_blocks,
                                k_blocks,
                            )
                            best_blocking = get_better_blocking(blocking, best_blocking)

        if best_blocking is None:
            # Last resort: accept partial occupancy (only one of m/n has
            # enough blocks for its factor).
            for n_factor in factors:
                m_factor = self.num_threads // n_factor
                if n_blocks >= n_factor or m_blocks >= m_factor:
                    blocking = get_blocking(
                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                    )
                    best_blocking = get_better_blocking(blocking, best_blocking)

        assert best_blocking is not None
        return best_blocking
372
+
373
    @cache_on_self
    def cache_blocking(self) -> GemmBlocking:
        """
        Decide the cache blocking (Mc, Nc, Kc), in units of register blocks,
        bounded by the thread blocking. See NOTE [CPP GEMM Cache Blocking
        Algorithm] inside. Requires static M (asserts not is_dynamic_M).
        """

        def get_cache_blocking(register_blocking, thread_blocking):
            Mr = register_blocking.block_m
            Nr = register_blocking.block_n
            Kr = register_blocking.block_k

            Mt_blocks = thread_blocking.block_m
            Nt_blocks = thread_blocking.block_n
            Kt_blocks = thread_blocking.block_k

            # User override: "Mc,Nc,Kc" from config, clamped to the thread blocking.
            if config.cpp.gemm_cache_blocking is not None:
                blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
                assert len(blockings) == 3
                Mc_blocks, Nc_blocks, Kc_blocks = blockings
                return (
                    min(Mc_blocks, Mt_blocks),
                    min(Nc_blocks, Nt_blocks),
                    min(Kc_blocks, Kt_blocks),
                )

            # The ratios below are empirically determined to decide
            # the effective sizes of L1 and L2.
            # TODO: tune the factor here
            L1_limit_factor = 0.8
            L2_limit_factor = 0.5

            L1_cache_size = (
                torch._C._cpu._L1d_cache_size()
            )  # per core cache size in Bytes
            assert (
                L1_cache_size > 0
            ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
            L1 = L1_cache_size * L1_limit_factor

            L2_cache_size = (
                torch._C._cpu._L2_cache_size()
            )  # per core cache size in Bytes
            assert (
                L2_cache_size > 0
            ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
            L2 = L2_cache_size * L2_limit_factor

            def get_num_byte(dtype):
                # Element size in bytes for the given dtype.
                return torch.tensor([], dtype=dtype).element_size()

            num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
            num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())

            # NOTE [CPP GEMM Cache Blocking Algorithm]
            # Our overall strategy is to
            # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
            #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
            #    decide Kc. We want to make Mc large enough to better reuse B.
            # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
            #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

            # Step 1: Decide Kc assuming B block is L1-reside.
            size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
            Kc_blocks = Kt_blocks
            if size_cache_B > L1:
                Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))

            # Step 2: Decide Mc assuming A block is L2-reside.
            min_Mc_ratio = 2  # TODO(jgong5): something to tune?
            min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
            assert min_Mc_blocks >= 1
            Kt_bytes = Kt_blocks * Kr * num_byte_A
            if min_Mc_blocks * Mr * Kt_bytes < L2:
                # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
                # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
                # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
                # in L1.
                Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
                Nc_blocks = 1
            else:
                # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
                # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
                Mc_blocks = Mt_blocks
                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
                Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
                Kc_bytes = Kc_blocks * Kr * num_byte_A
                if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
                    # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
                    # assuming Mc == Nc for good data reuse.
                    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
                    if M_max < Mc_blocks * Mr:
                        Mc_blocks = math.floor(M_max / Mr)
                        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)

            return Mc_blocks, Nc_blocks, Kc_blocks

        assert (
            not self.is_dynamic_M
        ), "Unable to determine cache blocking for dynamic M."
        register_blocking = self.register_blocking
        thread_blocking = self.thread_blocking()

        return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
472
+
473
+ def log_blockings(self):
474
+ log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004
475
+ if self.is_dynamic_M:
476
+ # thread and cache blockings are determined at runtime for dynamic shapes
477
+ return
478
+ log.debug(f"Cache blocking: {self.cache_blocking()}") # noqa: G004
479
+ thread_blocking = self.thread_blocking()
480
+ log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004
481
+
482
+ def get_occupancy():
483
+ m_blocks = math.ceil(self.m / self.register_blocking.block_m)
484
+ n_blocks = math.ceil(self.n / self.register_blocking.block_n)
485
+ k_blocks = math.ceil(self.k / self.register_blocking.block_k)
486
+ m = math.ceil(m_blocks / thread_blocking.block_m)
487
+ n = math.ceil(n_blocks / thread_blocking.block_n)
488
+ k = math.ceil(k_blocks / thread_blocking.block_k)
489
+ return (m, n, k)
490
+
491
+ log.debug(
492
+ f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004
493
+ )
494
+
495
+ def maybe_k_slicing(self):
496
+ if self.num_threads == 1:
497
+ return False
498
+ if self.is_dynamic_M:
499
+ # TODO(jgong5): perhaps use size hint to decide?
500
+ return True
501
+ register_blocking = self.register_blocking
502
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
503
+ thread_blocking = self.thread_blocking()
504
+ return k_blocks > thread_blocking.block_k
505
+
506
    @staticmethod
    def add_choices(
        choices,
        layout,
        input_nodes,
        beta=1,
        alpha=1,
        has_bias=False,
        trans_w=False,
        input_indices=None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    ):
        """
        Build a CppPackedGemmTemplate choice for the given GEMM and append it
        to ``choices``.

        Wraps the template in a DataProcessorTemplateWrapper whose
        ``preprocessor`` reorders inputs to [x, w, inp, ...], densifies and
        (optionally) transposes W, and packs W into a blocked
        [padded_n // block_n, k, block_n] layout; the ``postprocessor``
        replaces the template buffer's weight input with the packed constant
        and prunes the original weight from the graph when it has no other
        users.

        Args:
            choices: list to which the new choice is appended.
            layout: output layout of the GEMM.
            input_nodes: input IR nodes, ordered per ``input_indices``.
            beta, alpha: GEMM scaling factors.
            has_bias: whether a bias/inp input is present (expected input
                order is then [inp, x, w, ...]).
            trans_w: whether W must be transposed before packing.
            input_indices: permutation of ``input_nodes`` positions; defaults
                to identity.
            epilogue_creator: optional builder for an in-template epilogue.

        Returns:
            The created DataProcessorTemplateWrapper.
        """
        if input_indices is None:
            input_indices = list(range(len(input_nodes)))

        def reorder_and_filter(inputs, layout_or_out):
            # Reorder raw inputs into the template's canonical order.
            if has_bias:
                assert len(input_indices) >= 3
                # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
                inp_idx = input_indices[0]
                x_idx = input_indices[1]
                w_idx = input_indices[2]
                return [
                    inputs[x_idx],
                    inputs[w_idx],
                    inputs[inp_idx],
                    *[inputs[idx] for idx in input_indices[3:]],
                ], layout_or_out
            else:
                assert len(input_indices) >= 2
                return [inputs[idx] for idx in input_indices], layout_or_out

        def maybe_to_dense(inputs, layout_or_out):
            # Densify an MKLDNN weight tensor so it can be packed below.
            new_inputs = list(inputs)
            if isinstance(inputs[1], torch.Tensor):
                W = inputs[1]
                new_inputs[1] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(inputs, layout_or_out):
            # Transpose W (and broadcast bias B to X's batch dim) when trans_w.
            # Works on both IR nodes (compile time) and real tensors (autotune).
            if not trans_w:
                return inputs, layout_or_out
            new_inputs = list(inputs)
            X = inputs[0]
            W = inputs[1]
            B = inputs[2] if has_bias else None
            if isinstance(W, ir.IRNode):
                if trans_w:
                    if not isinstance(W, ir.TensorBox):
                        W = ir.TensorBox(W)
                    W = L.permute(W, [1, 0])
            else:
                if trans_w:
                    assert isinstance(W, torch.Tensor)
                    W = W.transpose(0, 1)
            if B is not None:
                if isinstance(B, ir.IRNode):
                    if not isinstance(B, ir.TensorBox):
                        B = ir.TensorBox(B)
                    B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
                else:
                    assert isinstance(B, torch.Tensor)
                    B = B.expand(X.shape[0], B.shape[-1])
            new_inputs[1] = W
            if B is not None:
                new_inputs[2] = B
            return new_inputs, layout_or_out

        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(
            *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
        )
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        # The micro-GEMM choice fixes the register blocking used for packing.
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[1].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
        padded_n = get_padded_n(n, block_n)

        def pack_weight(inputs, layout_or_out):
            # Pack W into [padded_n // block_n, k, block_n]. For IR nodes this
            # only declares the new layout (reusing the buffer name); for real
            # tensors it pads/reshapes, and applies VNNI re-layout if the
            # micro-GEMM requires it.
            W = inputs[1]
            new_inputs = list(inputs)
            blocked_w: Union[ir.IRNode, torch.Tensor] = W
            if isinstance(W, ir.IRNode):
                new_size = [padded_n // block_n, k, block_n]
                blocked_w = ir.Buffer(
                    W.get_name(),  # Borrow the registered buffer name
                    ir.FixedLayout(
                        W.get_device(),
                        W.get_dtype(),
                        new_size,
                        ir.FlexibleLayout.contiguous_strides(new_size),
                        0,
                    ),
                )
            else:
                blocked_w = (
                    torch.nn.functional.pad(W, (0, padded_n - n))
                    .reshape(k, padded_n // block_n, block_n)
                    .transpose(0, 1)
                    .contiguous()
                )
                if micro_gemm.get_b_layout() != LayoutType.NORMAL:
                    layout_str = (
                        "VNNI4"
                        if micro_gemm.get_b_layout() == LayoutType.VNNI4
                        else "VNNI2"
                    )
                    assert micro_gemm.get_b_layout() in [
                        LayoutType.VNNI2,
                        LayoutType.VNNI4,
                    ], f"We only support {layout_str} for now"
                    vnni_size = (
                        4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
                    )
                    assert (
                        k % vnni_size == 0
                    ), f"k should be divisible by vnni_size for {layout_str} layout"
                    blocked_w = (
                        blocked_w.view(
                            padded_n // block_n, k // vnni_size, vnni_size, block_n
                        )
                        .transpose(-1, -2)
                        .contiguous()
                        .view(padded_n // block_n, k, block_n)
                    )
                # normalize stride to be "contiguous_strides" per size
                # this avoids the problems in L.view during template codegen
                new_stride = [1]
                for sz in reversed(blocked_w.shape[1:]):
                    new_stride.insert(0, new_stride[0] * sz)
                blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
            new_inputs[1] = blocked_w

            def _is_int8_gemm(inputs):
                # True when X is uint8 (quantized GEMM), for IR node or tensor.
                return (
                    isinstance(inputs[0], ir.IRNode)
                    and inputs[0].get_dtype() == torch.uint8
                ) or (
                    isinstance(inputs[0], torch.Tensor)
                    and inputs[0].dtype == torch.uint8
                )

            if _is_int8_gemm(new_inputs):
                # int8 GEMM needs a per-column compensation term (sum over K of W).
                BCompensate = None
                if isinstance(W, ir.IRNode):
                    BCompensate = V.graph.add_tensor_constant(
                        V.graph.constants[W.get_name() + "_BMatrixCompens"],
                        W.get_name() + "_BMatrixCompens",
                    )
                else:
                    BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
                new_inputs.append(BCompensate)
            return new_inputs, layout_or_out

        def preprocessor(inputs, layout):
            # Full input pipeline: reorder -> densify -> normalize -> pack.
            return pack_weight(
                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
            )

        def postprocessor(output):
            if isinstance(output, ir.TensorBox):
                # prepack the weight as input to the template buffer
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)

                W_node = new_input_nodes[1]
                assert W_node.get_name() in V.graph.constants
                W = V.graph.constants[W_node.get_name()]
                new_input_nodes[1] = W
                new_input_nodes, _ = pack_weight(
                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                )

                # By using the new packed weight for the GEMM template, we can prune the
                # old weight if it has no other users. This saves memory but makes the FX graph
                # non-retraceable. To support retracing, we can add a repack node to the
                # FX graph. For example:
                # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
                W_tensor_users = 0
                for node in reversed(V.graph.graph.nodes):
                    # Case may happen when the wgt tensor is used by more than 1 get_attr node
                    # https://github.com/pytorch/pytorch/issues/134998
                    if node.op == "get_attr" and hasattr(
                        V.graph.module, node.name
                    ):  # wgt might already be deleted
                        comp_tensor = getattr(V.graph.module, node.name)
                        if (
                            W.is_mkldnn == comp_tensor.is_mkldnn
                            and W.dtype == comp_tensor.dtype
                            and W.device == comp_tensor.device
                            and (
                                (
                                    not W.is_mkldnn
                                    and (
                                        W.untyped_storage().data_ptr()
                                        == comp_tensor.untyped_storage().data_ptr()
                                    )
                                )
                                or (
                                    W.is_mkldnn
                                    and (
                                        torch.ops.mkldnn.data_ptr(W)
                                        == torch.ops.mkldnn.data_ptr(comp_tensor)
                                    )
                                )
                            )
                        ):
                            W_tensor_users += 1

                for node in reversed(V.graph.graph.nodes):
                    # The wgt tensor has been used by only 1 get_attr node
                    # The get_attr node has only 1 user fx node
                    if (
                        node.name == W_node.get_name()
                        and len(node.users) == 1
                        and W_tensor_users == 1
                    ):
                        del V.graph.constants[node.name]
                        delattr(V.graph.module, node.name)
                        delattr(V.graph.graph.owning_module, node.name)

                W_packed = new_input_nodes[1]
                W_packed_constant = V.graph.add_tensor_constant(W_packed)
                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                    W_packed_constant
                )
            return output

        template = DataProcessorTemplateWrapper(
            CppPackedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
        )
        template.maybe_append_choice(choices)
        return template
765
+
766
+ def render( # type: ignore[override,return]
767
+ self,
768
+ kernel: CppTemplateKernel,
769
+ template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
770
+ flag_template_buffer_has_other_users: Optional[bool] = None,
771
+ epilogue_nodes: Optional[List[ir.IRNode]] = None,
772
+ **kwargs,
773
+ ) -> str:
774
+ assert len(self.input_nodes) >= 2
775
+
776
+ int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8
777
+ x_scale = None
778
+ x_zp = None
779
+ w_scale = None
780
+ w_zp = None
781
+ if int8_gemm:
782
+ X, W = self.input_nodes[0], self.input_nodes[1]
783
+ bias_idx = 2 if self.has_bias else 1
784
+ inp = self.input_nodes[bias_idx] if self.has_bias else None
785
+ x_scale = self.input_nodes[bias_idx + 1]
786
+ x_zp = self.input_nodes[bias_idx + 2]
787
+ w_scale = self.input_nodes[bias_idx + 3]
788
+ w_zp = self.input_nodes[bias_idx + 4]
789
+ Y = self.output_node
790
+ else:
791
+ X, W = self.input_nodes[0], self.input_nodes[1]
792
+ Y = self.output_node
793
+ inp = self.input_nodes[2] if self.has_bias else None
794
+
795
+ template_buffer_has_other_users = None
796
+
797
+ if template_buffer_node is not None:
798
+ # Use the updated prepacked weight buffer
799
+ W = template_buffer_node.inputs[1]
800
+ Y = template_buffer_node
801
+
802
+ assert flag_template_buffer_has_other_users is not None
803
+ template_buffer_has_other_users = flag_template_buffer_has_other_users
804
+
805
+ template_buffer = Y
806
+ gemm_output_buffer = template_buffer
807
+
808
+ epilogues: List[ir.IRNode] = []
809
+ reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
810
+ epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
811
+ fake_buffers: List[ir.Buffer] = []
812
+ Y_aliases: Set[str] = set()
813
+
814
+ use_local_acc = (
815
+ self.layout.dtype != torch.float
816
+ or template_buffer_has_other_users
817
+ or int8_gemm
818
+ or self.padded_n != self.n
819
+ or self.maybe_k_slicing()
820
+ )
821
+
822
+ # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
823
+ # but we'd better move it here to align with fp.
824
+ if inp is not None and self.beta != 0 and not int8_gemm:
825
+ # add an epilogue for bias add
826
+ def _bias_add_epilogue(buf):
827
+ return create_epilogue_with_attr(
828
+ buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
829
+ )
830
+
831
+ epilogue_creators.append(_bias_add_epilogue)
832
+
833
+ if self.epilogue_creator is not None:
834
+ epilogue_creators.append(self.epilogue_creator)
835
+
836
+ # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
837
+ # we need to copy the value in the GEMM output local buffer to a global buffer.
838
+ def need_copy_from_local_to_global_buffer_epilogue(
839
+ use_local_acc, template_buffer_has_other_users, epilogue_creators
840
+ ):
841
+ # The GEMM output buffer is a global buffer, thus copy is not needed.
842
+ if not use_local_acc:
843
+ return False
844
+
845
+ # The possible value of template_buffer_has_other_users is (None, False, True)
846
+ # It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
847
+ # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
848
+ # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
849
+ # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
850
+ # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
851
+ if not template_buffer_has_other_users:
852
+ return False
853
+
854
+ # When bias is not None or self.epilogue_creator is not None,
855
+ # there will be epilogue_creators after the GEMM.
856
+ # The GEMM output buffer is localized while
857
+ # the output buffer of the epilogue_creators is a global buffer.
858
+ if epilogue_creators:
859
+ return False
860
+
861
+ return True
862
+
863
+ if need_copy_from_local_to_global_buffer_epilogue(
864
+ use_local_acc, template_buffer_has_other_users, epilogue_creators
865
+ ):
866
+
867
+ def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
868
+ dtype = self.layout.dtype
869
+ input_loader = input_buffer.make_loader()
870
+
871
+ def copy_inner(index):
872
+ input = input_loader(index)
873
+ result = ops.to_dtype(input, dtype)
874
+ return result
875
+
876
+ return ir.Pointwise(
877
+ device=input_buffer.get_device(),
878
+ dtype=self.layout.dtype,
879
+ inner_fn=copy_inner,
880
+ ranges=input_buffer.get_size(),
881
+ )
882
+
883
+ epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
884
+
885
+ # NOTE [How CPP GEMM template epilogues are organized]
886
+ # gemm_output_buffer
887
+ # --> zero or more in-template epilogues (created by `epilogue_creators`) -->
888
+ # template_buffer
889
+ # --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
890
+ # Y
891
+ if epilogue_creators:
892
+ gemm_output_name = "buf_GemmOut"
893
+ gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
894
+ current_input_buffer = gemm_output_buffer
895
+ for i, creator in enumerate(epilogue_creators):
896
+ if i == len(epilogue_creators) - 1:
897
+ buffer_name = template_buffer.get_name()
898
+ else:
899
+ buffer_name = f"buf_GemmOut_epilogue_{i}"
900
+ epilogues.append(
901
+ ir.ComputedBuffer(
902
+ name=buffer_name,
903
+ layout=template_buffer.layout,
904
+ data=creator(current_input_buffer),
905
+ )
906
+ )
907
+ fake_buffers.append(current_input_buffer)
908
+ Y_aliases.add(current_input_buffer.get_name())
909
+ reindexers.append(None)
910
+ if i < len(epilogue_creators) - 1:
911
+ current_input_buffer = ir.Buffer(
912
+ buffer_name, template_buffer.layout
913
+ )
914
+
915
+ Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
916
+
917
+ if epilogue_nodes:
918
+ epilogues.extend(epilogue_nodes)
919
+ assert Y.get_numel() == epilogues[-1].get_numel()
920
+ Y = cast(ir.Buffer, epilogues[-1])
921
+
922
+ if not template_buffer_has_other_users:
923
+ Y_aliases.add(template_buffer.get_name())
924
+
925
+ if (
926
+ Y.get_size() == template_buffer.get_size()
927
+ and Y.get_stride() == template_buffer.get_stride()
928
+ ):
929
+ reindexers.extend([None] * len(epilogue_nodes))
930
+ Y_2d = Y
931
+ else:
932
+
933
+ def get_reindexer(epilogue_node):
934
+ # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
935
+ # template_buffer:
936
+ # size (324, 512), stride (512, 1)
937
+ # epilogue_node_ordered (ordered by stride decreasingly, in dense format):
938
+ # size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
939
+ stride_order = list(
940
+ ir.get_stride_order(
941
+ V.graph.sizevars.size_hints(epilogue_node.get_stride())
942
+ )
943
+ )
944
+ fill_order = ir.stride_order2fill_order(stride_order)
945
+ reversed_fill_order = list(reversed(fill_order))
946
+ size_with_stride_ordered_decreasingly = [
947
+ epilogue_node.get_size()[i] for i in reversed_fill_order
948
+ ]
949
+ reshape_reindex = ir.View.dynamic_reshape_indexer(
950
+ size_with_stride_ordered_decreasingly,
951
+ template_buffer.get_size(),
952
+ )
953
+
954
+ # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
955
+ # epilogue_node_ordered (ordered by stride decreasingly, in dense format):
956
+ # size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
957
+ # epilogue_node:
958
+ # size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
959
+ from_stride_ordered_decreasingly_to_epilogue_node_order = [
960
+ (len(stride_order) - 1) - stride_order[i]
961
+ for i in range(len(stride_order))
962
+ ]
963
+ stride_reindex = ir.same_reorder(
964
+ from_stride_ordered_decreasingly_to_epilogue_node_order
965
+ )
966
+
967
+ reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)
968
+ return reindexer
969
+
970
+ reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item]
971
+ if isinstance(Y, ir.BaseView):
972
+ storage = ir.StorageBox(Y.unwrap_view())
973
+ else:
974
+ assert isinstance(Y, ir.Buffer)
975
+ storage = ir.StorageBox(Y)
976
+ Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout())
977
+
978
+ output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
979
+ X.get_dtype()
980
+ )
981
+ micro_gemm = create_micro_gemm(
982
+ f"{kernel.kernel_name}_micro_gemm",
983
+ self.m,
984
+ self.n,
985
+ self.k,
986
+ input_dtype=X.get_dtype(),
987
+ input2_dtype=W.get_dtype(),
988
+ output_dtype=output_dtype,
989
+ compute_dtype=compute_dtype,
990
+ alpha=self.alpha,
991
+ num_threads=self.num_threads,
992
+ )
993
+ assert micro_gemm is not None
994
+ assert self.register_blocking == micro_gemm.register_blocking
995
+ self.log_blockings()
996
+ if isinstance(micro_gemm, CppMicroGemmAMX):
997
+ counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
998
+
999
+ L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
1000
+ assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
1001
+
1002
+ L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
1003
+ assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
1004
+
1005
+ options = dict(
1006
+ X=X,
1007
+ W=W,
1008
+ inp=inp,
1009
+ Y=Y,
1010
+ N=self.n,
1011
+ K=self.k,
1012
+ PADDED_N=self.padded_n,
1013
+ GemmOut=gemm_output_buffer,
1014
+ aliases={alias: Y.get_name() for alias in Y_aliases},
1015
+ beta=self.beta,
1016
+ alpha=self.alpha,
1017
+ num_threads=self.num_threads,
1018
+ micro_gemm=micro_gemm,
1019
+ is_dynamic_M=self.is_dynamic_M,
1020
+ template=self,
1021
+ kernel=kernel,
1022
+ export_declaration=get_export_declaration(),
1023
+ epilogue_nodes=epilogues,
1024
+ reindexers=reindexers,
1025
+ Y_2d=Y_2d,
1026
+ use_local_acc=use_local_acc,
1027
+ maybe_k_slicing=self.maybe_k_slicing(),
1028
+ x_scale=x_scale,
1029
+ x_zp=x_zp,
1030
+ w_scale=w_scale,
1031
+ w_zp=w_zp,
1032
+ acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
1033
+ DTYPE_TO_CPP=DTYPE_TO_CPP,
1034
+ L1_cache_size=L1_cache_size,
1035
+ L2_cache_size=L2_cache_size,
1036
+ config=config,
1037
+ )
1038
+ with contextlib.ExitStack() as stack:
1039
+ for buf in fake_buffers:
1040
+ stack.enter_context(
1041
+ patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
1042
+ )
1043
+ return self._template_from_string(GEMM_TEMPLATE).render(**options)
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import dataclasses
3
+ import sys
4
+ from enum import Enum
5
+ from typing import Callable, Dict, List, Optional, Type
6
+
7
+ import sympy
8
+
9
+ import torch
10
+
11
+ from .. import ir
12
+ from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
13
+ from ..utils import IndentedBuffer, parallel_num_threads
14
+ from ..virtualized import V
15
+ from .common import KernelTemplate
16
+ from .cpp_template_kernel import CppTemplateKernel
17
+ from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
18
+
19
+
20
+ class LayoutType(Enum):
21
+ NORMAL = 0
22
+ VNNI2 = 1
23
+ VNNI4 = 2
24
+
25
+
26
+ _IS_WINDOWS = sys.platform == "win32"
27
+
28
+
29
+ def get_restrict_keyword() -> str:
30
+ if _IS_WINDOWS:
31
+ # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170
32
+ return "__restrict"
33
+ else:
34
+ return "__restrict__"
35
+
36
+
37
+ class CppMicroGemm:
38
+ """
39
+ A class that codegens a kernel that computes small-sized matrix multiplication.
40
+
41
+ A micro GEMM kernel is responsible for register blocking, instruction selection,
42
+ and other CPU architecture-specific optimizations.
43
+
44
+ The subclasses need to override `codegen_define` to define the kernel function
45
+ that is called by the code generated by `codegen_call`.
46
+ """
47
+
48
+ # TODO(jgong5): support constant shapes and lds as template args.
49
+ DECLARE_KERNEL = r"""
50
+ template <bool accum>
51
+ inline void {{kernel_name}}(
52
+ {%- if kernel_extra_args_declare %}
53
+ {{kernel_extra_args_declare}}
54
+ {%- endif %}
55
+ const {{input_t}}* {{restrict_keyword}} A,
56
+ const {{input2_t}}* {{restrict_keyword}} B,
57
+ {{output_t}}* {{restrict_keyword}} C,
58
+ int64_t M,
59
+ int64_t N,
60
+ int64_t K,
61
+ int64_t lda,
62
+ int64_t ldb,
63
+ int64_t ldc
64
+ )
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ name,
70
+ input_dtype,
71
+ input2_dtype,
72
+ output_dtype,
73
+ compute_dtype,
74
+ register_blocking,
75
+ alpha=1,
76
+ ) -> None:
77
+ self.name = name
78
+ self.input_dtype = input_dtype
79
+ assert input2_dtype is not None
80
+ self.input2_dtype = input2_dtype
81
+ self.output_dtype = output_dtype
82
+ self.compute_dtype = compute_dtype
83
+ self.register_blocking = register_blocking
84
+ self.alpha = alpha
85
+
86
+ def get_common_options(self):
87
+ if self.input_dtype == torch.uint8:
88
+ assert self.compute_dtype == torch.int32
89
+ assert self.output_dtype == torch.int32
90
+ assert self.input2_dtype == torch.int8
91
+ return {
92
+ "torch": torch,
93
+ "kernel_name": self.name,
94
+ "input_dtype": self.input_dtype,
95
+ "input2_dtype": self.input2_dtype,
96
+ "output_dtype": self.output_dtype,
97
+ "compute_dtype": self.compute_dtype,
98
+ "input_t": DTYPE_TO_CPP[self.input_dtype],
99
+ "input2_t": DTYPE_TO_CPP[self.input2_dtype],
100
+ "output_t": DTYPE_TO_CPP[self.output_dtype],
101
+ "compute_t": DTYPE_TO_CPP[self.compute_dtype],
102
+ "alpha": self.alpha,
103
+ "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
104
+ "int8_gemm": self.input_dtype == torch.uint8,
105
+ "vnni_size": 4 if self.input_dtype == torch.uint8 else 2,
106
+ "restrict_keyword": get_restrict_keyword(),
107
+ }
108
+
109
+ def get_kernel_declaration(self):
110
+ options = self.get_common_options()
111
+ return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)
112
+
113
+ def get_kernel_extra_args_declare(self) -> str:
114
+ return ""
115
+
116
+ def get_kernel_extra_args(self) -> str:
117
+ return ""
118
+
119
+ def codegen_define(self, kernel: CppTemplateKernel) -> str:
120
+ raise NotImplementedError
121
+
122
+ def codegen_call(
123
+ self,
124
+ kernel: CppTemplateKernel,
125
+ A: ir.Buffer,
126
+ B: ir.Buffer,
127
+ C: ir.Buffer,
128
+ accum: bool,
129
+ ) -> str:
130
+ """
131
+ Generate the code for calling the templated kernel that computes
132
+ `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
133
+ """
134
+ A_ptr = f"&({kernel.index(A, [0, 0])})"
135
+ B_ptr = f"&({kernel.index(B, [0, 0])})"
136
+ C_ptr = f"&({kernel.index(C, [0, 0])})"
137
+ M = kernel.size(C, 0)
138
+ N = kernel.size(C, 1)
139
+ K = kernel.size(A, 1)
140
+ lda = kernel.stride(A, 0)
141
+ ldb = kernel.stride(B, 0)
142
+ ldc = kernel.stride(C, 0)
143
+ res = IndentedBuffer()
144
+ res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
145
+ with res.indent():
146
+ extra_args = self.get_kernel_extra_args()
147
+ if extra_args:
148
+ res.writeline(extra_args)
149
+ res.writeline(f"{A_ptr},")
150
+ res.writeline(f"{B_ptr},")
151
+ res.writeline(f"{C_ptr},")
152
+ res.writeline(f"{M},")
153
+ res.writeline(f"{N},")
154
+ res.writeline(f"{K},")
155
+ res.writeline(f"{lda},")
156
+ res.writeline(f"{ldb},")
157
+ res.writeline(f"{ldc}")
158
+ res.writeline(");")
159
+ return res.getvalue()
160
+
161
+ def codegen_init(
162
+ self,
163
+ kernel: CppTemplateKernel,
164
+ ) -> str:
165
+ return ""
166
+
167
+ def codegen_finalize(
168
+ self,
169
+ kernel: CppTemplateKernel,
170
+ ) -> str:
171
+ return ""
172
+
173
+ def get_b_layout(self) -> LayoutType:
174
+ return LayoutType.NORMAL
175
+
176
+
177
+ @dataclasses.dataclass
178
+ class CppMicroGemmConfig:
179
+ input_dtype: torch.dtype
180
+ input2_dtype: torch.dtype
181
+ output_dtype: torch.dtype
182
+ compute_dtype: torch.dtype
183
+ vec_isa_cls: Type[VecISA]
184
+ register_blocking: GemmBlocking
185
+ extra_check: Optional[Callable[..., bool]] = None
186
+
187
+
188
+ micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {}
189
+
190
+
191
+ def register_micro_gemm(*configs):
192
+ def inner(cls):
193
+ assert (
194
+ cls not in micro_gemm_configs
195
+ ), f"Duplicate micro_gemm registration for {cls}"
196
+ assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
197
+ micro_gemm_configs[cls] = list(configs)
198
+ return cls
199
+
200
+ return inner
201
+
202
+
203
+ def generate_gemm_config(
204
+ vec_isa_cls,
205
+ register_blockings,
206
+ input_dtype=torch.float,
207
+ input2_dtype=None,
208
+ output_dtype=None,
209
+ compute_dtype=None,
210
+ extra_check=None,
211
+ ):
212
+ if output_dtype is None:
213
+ output_dtype = input_dtype
214
+ if compute_dtype is None:
215
+ compute_dtype = output_dtype
216
+ if input2_dtype is None:
217
+ input2_dtype = input_dtype
218
+ return [
219
+ CppMicroGemmConfig(
220
+ input_dtype,
221
+ input2_dtype,
222
+ output_dtype,
223
+ compute_dtype,
224
+ vec_isa_cls,
225
+ GemmBlocking(*blocking),
226
+ extra_check,
227
+ )
228
+ for blocking in register_blockings
229
+ ]
230
+
231
+
232
+ class CppMicroGemmRef(CppMicroGemm):
233
+ """
234
+ A reference implementation of the CppMicroGemm class with naive C++ code.
235
+ It is used for correctness debugging.
236
+ """
237
+
238
+ TEMPLATE_ENTRY = r"""
239
+ {{declare_kernel}} {
240
+ for (int64_t m = 0; m < M; ++m) {
241
+ for (int64_t n = 0; n < N; ++n) {
242
+ {{compute_t}} result = accum ? C[m * ldc + n] : 0;
243
+ for (int64_t k = 0; k < K; ++k) {
244
+ result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
245
+ }
246
+ C[m * ldc + n] = result;
247
+ }
248
+ }
249
+ }
250
+ """
251
+
252
+ def __init__(
253
+ self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
254
+ ) -> None:
255
+ super().__init__(
256
+ name,
257
+ input_dtype,
258
+ input2_dtype,
259
+ output_dtype,
260
+ compute_dtype,
261
+ GemmBlocking(1, 1, 1),
262
+ alpha,
263
+ )
264
+
265
+ def codegen_define(self, kernel: CppTemplateKernel) -> str:
266
+ options = {
267
+ "declare_kernel": self.get_kernel_declaration(),
268
+ **self.get_common_options(),
269
+ }
270
+ return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
271
+
272
+
273
+ @register_micro_gemm(
274
+ *generate_gemm_config(
275
+ VecAVX512,
276
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
277
+ input_dtype=torch.float,
278
+ ),
279
+ *generate_gemm_config(
280
+ VecAVX512,
281
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
282
+ input_dtype=torch.bfloat16,
283
+ output_dtype=torch.float,
284
+ ),
285
+ *generate_gemm_config(
286
+ VecAVX512,
287
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
288
+ input_dtype=torch.half,
289
+ output_dtype=torch.float,
290
+ ),
291
+ *generate_gemm_config(
292
+ VecAVX512,
293
+ [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
294
+ input_dtype=torch.bfloat16,
295
+ input2_dtype=torch.int8,
296
+ output_dtype=torch.float,
297
+ compute_dtype=torch.float,
298
+ ),
299
+ *generate_gemm_config(
300
+ VecAVX2,
301
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
302
+ input_dtype=torch.float,
303
+ ),
304
+ *generate_gemm_config(
305
+ VecAVX2,
306
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
307
+ input_dtype=torch.bfloat16,
308
+ output_dtype=torch.float,
309
+ ),
310
+ *generate_gemm_config(
311
+ VecAVX2,
312
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
313
+ input_dtype=torch.half,
314
+ output_dtype=torch.float,
315
+ ),
316
+ *generate_gemm_config(
317
+ VecAVX2,
318
+ [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
319
+ input_dtype=torch.bfloat16,
320
+ input2_dtype=torch.int8,
321
+ output_dtype=torch.float,
322
+ compute_dtype=torch.float,
323
+ ),
324
+ )
325
+ class CppMicroGemmFP32Vec(CppMicroGemm):
326
+ """
327
+ This class generates the code for micro gemm using fp32 vec instructions for compute.
328
+ It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
329
+ The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
330
+ if the desired output is BF16/FP16.
331
+ """
332
+
333
+ TEMPLATE_ENTRY = r"""
334
+ {{declare_kernel}} {
335
+ TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
336
+ TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
337
+ // TODO(jgong5): loop unroll for M and N
338
+ for (int64_t m = 0; m < M; m += {{block_m}}) {
339
+ int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
340
+ for (int64_t n = 0; n < N; n += {{block_n}}) {
341
+ if (block_m == {{block_m}}) {
342
+ {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
343
+ A + m * lda,
344
+ B + n,
345
+ C + m * ldc + n,
346
+ K,
347
+ lda,
348
+ ldb,
349
+ ldc
350
+ );
351
+ } else {
352
+ switch (block_m) {
353
+ {%- for b in range(block_m - 1, 0, -1) %}
354
+ case {{b}}:
355
+ {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
356
+ A + m * lda,
357
+ B + n,
358
+ C + m * ldc + n,
359
+ K,
360
+ lda,
361
+ ldb,
362
+ ldc
363
+ );
364
+ break;
365
+ {%- endfor %}
366
+ default:
367
+ {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
368
+ }
369
+ }
370
+ }
371
+ }
372
+ }
373
+ """
374
+
375
+ TEMPLATE_KERNEL = r"""
376
+ template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
377
+ inline void {{kernel_name}}_kernel(
378
+ const {{input_t}}* {{restrict_keyword}} A,
379
+ const {{input2_t}}* {{restrict_keyword}} B,
380
+ {{output_t}}* {{restrict_keyword}} C,
381
+ int64_t K,
382
+ int64_t lda,
383
+ int64_t ldb,
384
+ int64_t ldc
385
+ ) {
386
+ using Vectorized = at::vec::Vectorized<{{compute_t}}>;
387
+ using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
388
+ constexpr auto VLEN = Vectorized::size();
389
+ constexpr auto ROWS = BLOCK_M;
390
+ constexpr auto COLS = BLOCK_N / VLEN;
391
+
392
+ Vectorized va;
393
+ at::vec::VectorizedN<{{compute_t}}, COLS> vb;
394
+ at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
395
+
396
+ auto loadc = [&](auto i) {
397
+ if constexpr (accum) {
398
+ constexpr int row = i / COLS;
399
+ constexpr int col = i % COLS;
400
+ vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
401
+ } else {
402
+ vc[i] = Vectorized(0.0f);
403
+ }
404
+ };
405
+ c10::ForcedUnroll<ROWS * COLS>{}(loadc);
406
+
407
+ auto compute = [&, COLS](auto i, int k) {
408
+ constexpr int row = i / COLS;
409
+ constexpr int col = i % COLS;
410
+
411
+ if constexpr (col == 0) {
412
+ {%- if alpha != 1 %}
413
+ va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
414
+ {%- else %}
415
+ va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
416
+ {%- endif %}
417
+ }
418
+
419
+ if constexpr (row == 0) {
420
+ {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
421
+ auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
422
+ vb[col] = at::vec::convert<{{compute_t}}>(b);
423
+ {%- elif input2_dtype == torch.int8 %}
424
+ // Convert VLEN int8 elements to int32, and then fp32
425
+ auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
426
+ vb[col] = at::vec::convert<float>(b32);
427
+ {%- else %}
428
+ vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
429
+ {%- endif %}
430
+ }
431
+
432
+ constexpr int idx = row * COLS + col;
433
+ vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
434
+ };
435
+
436
+ for (int k = 0; k < K; ++k) {
437
+ c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
438
+ }
439
+
440
+ // store to C
441
+ auto storec = [&](auto i) {
442
+ constexpr int row = i / COLS;
443
+ constexpr int col = i % COLS;
444
+ vc[i].store(C + row * ldc + col * VLEN);
445
+ };
446
+ c10::ForcedUnroll<ROWS * COLS>{}(storec);
447
+ }
448
+ """
449
+
450
+ def codegen_define(self, kernel: CppTemplateKernel) -> str:
451
+ options = {
452
+ "declare_kernel": self.get_kernel_declaration(),
453
+ "kernel": kernel,
454
+ "block_m": self.register_blocking.block_m,
455
+ "block_n": self.register_blocking.block_n,
456
+ "block_k": self.register_blocking.block_k,
457
+ "restrict_keyword": get_restrict_keyword(),
458
+ **self.get_common_options(),
459
+ }
460
+ result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
461
+ options
462
+ )
463
+ result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
464
+ options
465
+ )
466
+ return result
467
+
468
+
469
+ # extra check for CppMicroGemmAMX
470
+ def check_amx_extra(config, m, n, k, alpha, num_threads):
471
+ vnni_size = 4 if config.input_dtype == torch.uint8 else 2
472
+ return k % vnni_size == 0 and alpha == 1
473
+
474
+
475
+ @register_micro_gemm(
476
+ *generate_gemm_config(
477
+ VecAMX,
478
+ [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
479
+ input_dtype=torch.bfloat16,
480
+ input2_dtype=torch.int8,
481
+ output_dtype=torch.float,
482
+ compute_dtype=torch.float,
483
+ extra_check=check_amx_extra,
484
+ ),
485
+ *generate_gemm_config(
486
+ VecAMX,
487
+ [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
488
+ input_dtype=torch.bfloat16,
489
+ output_dtype=torch.float,
490
+ extra_check=check_amx_extra,
491
+ ),
492
+ *generate_gemm_config(
493
+ VecAMX,
494
+ [(32, 32, 64), (48, 16, 64)],
495
+ input_dtype=torch.uint8,
496
+ input2_dtype=torch.int8,
497
+ output_dtype=torch.int32,
498
+ compute_dtype=torch.int32,
499
+ extra_check=check_amx_extra,
500
+ ),
501
+ )
502
+ class CppMicroGemmAMX(CppMicroGemm):
503
+ """
504
+ This class generates the code for micro gemm using Advanced Matrix eXtention (AMX)
505
+ instructions available in 4th generation Intel Xeon for compute.
506
+ It supports input types of torch.bfloat16 with fp32 output.
507
+ TODO(jgong5): support int8 data type.
508
+ """
509
+
510
+ TEMPLATE_ENTRY = r"""
511
+ {{declare_kernel}} {
512
+ TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
513
+ TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
514
+ // TODO(jgong5): loop unroll for M and N
515
+ for (int64_t m = 0; m < M; m += {{block_m}}) {
516
+ int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
517
+ int64_t m_tail = m;
518
+ for (int64_t n = 0; n < N; n += {{block_n}}) {
519
+ {%- for num_rows in range(block_m, 0, -16) %}
520
+ {%- if num_rows != block_m %}
521
+ else
522
+ {%- endif %}
523
+ if (block_m >= {{num_rows}}) {
524
+ {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
525
+ amx_state,
526
+ A + m * lda,
527
+ B + n,
528
+ C + m * ldc + n,
529
+ K,
530
+ lda,
531
+ ldb,
532
+ ldc,
533
+ 16
534
+ );
535
+ block_m -= {{num_rows}};
536
+ m_tail += {{num_rows}};
537
+ }
538
+ {%- endfor %}
539
+ if (block_m > 0) {
540
+ {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
541
+ amx_state,
542
+ A + m_tail * lda,
543
+ B + n,
544
+ C + m_tail * ldc + n,
545
+ K,
546
+ lda,
547
+ ldb,
548
+ ldc,
549
+ block_m
550
+ );
551
+ }
552
+ }
553
+ }
554
+ }
555
+ """
556
+
557
+ TEMPLATE_KERNEL = r"""
558
+ template <bool accum>
559
+ inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
560
+ AMXState& amx_state,
561
+ const {{input_t}}* {{restrict_keyword}} A,
562
+ const {{input2_t}}* {{restrict_keyword}} B,
563
+ {{output_t}}* {{restrict_keyword}} C,
564
+ int64_t K,
565
+ int64_t lda,
566
+ int64_t ldb,
567
+ int64_t ldc,
568
+ uint8_t tilecfg_rows
569
+ ) {
570
+ // TODO(jgong5): add prefetch hint for A, B, C
571
+ auto loadconfig = [](const amx_tilecfg& cfg) {
572
+ _tile_loadconfig(&cfg);
573
+ };
574
+ const auto last_k_offset = K / {{block_k}} * {{block_k}};
575
+ const auto tail_k_size = K - last_k_offset;
576
+ if C10_LIKELY (last_k_offset > 0) {
577
+ amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
578
+ } else {
579
+ amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
580
+ }
581
+ auto load_c = [&]() {
582
+ {%- for tile_row in range(num_rows // 16) %}
583
+ {%- for tile_col in range(num_columns) %}
584
+ {%- set tile_idx = tile_row * num_columns + tile_col %}
585
+ _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
586
+ {%- endfor %}
587
+ {%- endfor %}
588
+ };
589
+ auto zero_c = [&]() {
590
+ {%- for tile_row in range(num_rows // 16) %}
591
+ {%- for tile_col in range(num_columns) %}
592
+ {%- set tile_idx = tile_row * num_columns + tile_col %}
593
+ _tile_zero({{tile_idx}});
594
+ {%- endfor %}
595
+ {%- endfor %}
596
+ };
597
+
598
+ if constexpr (accum) {
599
+ load_c();
600
+ } else {
601
+ zero_c();
602
+ }
603
+
604
+ {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
605
+ // create a buffer for tiles of B.
606
+ alignas(64) {{input_t}} bf16_weights_buf[512];
607
+
608
+ int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4;
609
+ int b_tile_ptr_stride = ldb * {{vnni_size}};
610
+
611
+ auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) {
612
+ {{kernel.unroll_pragma(2)}}
613
+ for (int i = 0; i < 2; i++) {
614
+ // int8 -> int32 -> fp32 -> bf16
615
+ auto b32 = at::vec::convert_to_int32<int8_t>(src + i * 16);
616
+ auto b_bf16 = at::vec::convert<{{input_t}}>(b32);
617
+ b_bf16.store(dst + i * 16);
618
+ }
619
+ };
620
+
621
+ auto load_B_in_buf = [&]({{input2_t}}* B_ptr) {
622
+ {{kernel.unroll_pragma(8)}}
623
+ for (int i = 0; i < num_b_rows; i++) {
624
+ load_B_row(
625
+ B_ptr + i * b_tile_ptr_stride,
626
+ bf16_weights_buf + i * 32
627
+ );
628
+ }
629
+ };
630
+ {%- endif %}
631
+
632
+ auto compute = [&](int k) {
633
+ {%- set tile_offset_a = num_rows // 16 * num_columns %}
634
+ {%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
635
+ {%- for tile_row in range(num_rows // 16) %}
636
+ {%- for tile_col in range(num_columns) %}
637
+ {%- set tile_idx_a = tile_offset_a + tile_row %}
638
+ {%- set tile_idx_b = tile_offset_b + tile_col %}
639
+ {%- set tile_idx_c = tile_row * num_columns + tile_col %}
640
+ {%- if tile_col == 0 %}
641
+ _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
642
+ {%- endif %}
643
+ {%- if tile_row == 0 %}
644
+ {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
645
+ load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}});
646
+ _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64);
647
+ {%- else %}
648
+ _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
649
+ {%- endif %}
650
+ {%- endif %}
651
+ {%- if int8_gemm %}
652
+ _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
653
+ {%- else %}
654
+ _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
655
+ {%- endif %}
656
+ {%- endfor %}
657
+ {%- endfor %}
658
+ };
659
+
660
+ {{kernel.unroll_pragma(4)}}
661
+ for (int k = 0; k < last_k_offset; k += {{block_k}}) {
662
+ compute(k);
663
+ }
664
+
665
+ auto store_c = [&]() {
666
+ // store to C
667
+ {%- for tile_row in range(num_rows // 16) %}
668
+ {%- for tile_col in range(num_columns) %}
669
+ {%- set tile_idx = tile_row * num_columns + tile_col %}
670
+ _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
671
+ {%- endfor %}
672
+ {%- endfor %}
673
+ };
674
+
675
+ // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
676
+ if C10_UNLIKELY (tail_k_size > 0) {
677
+ if C10_LIKELY (last_k_offset > 0) {
678
+ store_c();
679
+ amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
680
+ load_c();
681
+ }
682
+ compute(last_k_offset);
683
+ }
684
+
685
+ store_c();
686
+ }
687
+ """
688
+
689
+ def codegen_define(self, kernel: CppTemplateKernel) -> str:
690
+ block_m, block_n, block_k = self.register_blocking
691
+ assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
692
+ assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
693
+ if self.input_dtype == torch.uint8:
694
+ assert block_k == 64, "Only support block_k = 64 for AMX INT8"
695
+ else:
696
+ assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
697
+ num_columns = block_n // 16
698
+ options = {
699
+ "declare_kernel": self.get_kernel_declaration(),
700
+ "kernel": kernel,
701
+ "block_m": block_m,
702
+ "block_n": block_n,
703
+ "block_k": block_k,
704
+ "num_columns": num_columns,
705
+ "restrict_keyword": get_restrict_keyword(),
706
+ **self.get_common_options(),
707
+ }
708
+ result = ""
709
+ for num_rows in range(block_m, 0, -16):
710
+ amx_kernel_options = {**options, "num_rows": num_rows}
711
+ result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
712
+ amx_kernel_options
713
+ )
714
+ result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
715
+ options
716
+ )
717
+ return result
718
+
719
+ def codegen_init(
720
+ self,
721
+ kernel: CppTemplateKernel,
722
+ ) -> str:
723
+ return "AMXState amx_state;"
724
+
725
+ def codegen_finalize(
726
+ self,
727
+ kernel: CppTemplateKernel,
728
+ ) -> str:
729
+ return "amx_state.release([]() { _tile_release(); });"
730
+
731
+ def get_kernel_extra_args_declare(self) -> str:
732
+ return "AMXState& amx_state,"
733
+
734
+ def get_kernel_extra_args(self) -> str:
735
+ return "amx_state,"
736
+
737
+ def get_b_layout(self):
738
+ if self.input_dtype == torch.uint8:
739
+ return LayoutType.VNNI4
740
+ else:
741
+ return LayoutType.VNNI2
742
+
743
+
744
+ def create_micro_gemm(
745
+ name,
746
+ m,
747
+ n,
748
+ k,
749
+ input_dtype,
750
+ input2_dtype,
751
+ output_dtype=None,
752
+ compute_dtype=None,
753
+ alpha=1,
754
+ num_threads=-1,
755
+ use_ref=True,
756
+ ) -> Optional[CppMicroGemm]:
757
+ def create_from_config(cls, config: CppMicroGemmConfig):
758
+ return cls(
759
+ name,
760
+ config.input_dtype,
761
+ config.input2_dtype,
762
+ config.output_dtype,
763
+ config.compute_dtype,
764
+ config.register_blocking,
765
+ alpha,
766
+ )
767
+
768
+ assert isinstance(n, int) or n.is_number, n
769
+ assert isinstance(k, int) or k.is_number, k
770
+ m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
771
+ assert isinstance(m, int), m
772
+ if output_dtype is None:
773
+ output_dtype = input_dtype
774
+ if compute_dtype is None:
775
+ compute_dtype = output_dtype
776
+ if num_threads < 0:
777
+ num_threads = parallel_num_threads()
778
+ vec_isa = pick_vec_isa()
779
+ matched_configs = []
780
+ for cls, configs in micro_gemm_configs.items():
781
+ for config in configs:
782
+ if not issubclass(vec_isa.__class__, config.vec_isa_cls):
783
+ continue
784
+ if (
785
+ config.input_dtype == input_dtype
786
+ and config.compute_dtype == compute_dtype
787
+ and config.input2_dtype == input2_dtype
788
+ and config.output_dtype == output_dtype
789
+ # The output_dtype here is the output dtype of the micro-kernel.
790
+ # In some cases, the actual output dtype of the op for which the micro-kernel
791
+ # is being created would be same as that of the activation, but the micro-kernels
792
+ # compute output in Float/int32, which is converted in the GEMM template. This is
793
+ # subject to change in the future.
794
+ ):
795
+ if config.extra_check is not None and not config.extra_check(
796
+ config, m, n, k, alpha, num_threads
797
+ ):
798
+ continue
799
+ block_m, block_n, block_k = config.register_blocking
800
+ if (
801
+ config.vec_isa_cls == VecAMX
802
+ and m < block_m
803
+ and input_dtype == torch.bfloat16
804
+ and input2_dtype == torch.int8
805
+ ):
806
+ # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
807
+ continue
808
+ # Criteria on the ranking of configurations
809
+ # 1. ISA: AMX > VEC
810
+ # 2. Dividable by block sizes (block_m, block_n, block_k)
811
+ # 3. Number of mxn blocks is large enough to occupy all the threads
812
+ # 4. Register blocks are larger
813
+ isa_score = 0
814
+ if config.vec_isa_cls == VecAMX:
815
+ isa_score += 1
816
+ dividable_score = 0
817
+ if m % block_m == 0:
818
+ dividable_score += 1
819
+ if n % block_n == 0:
820
+ dividable_score += 1
821
+ if k % block_k == 0:
822
+ dividable_score += 1
823
+ occupancy_score = 0
824
+ n_blocks = (n + block_n - 1) // block_n
825
+ total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
826
+ if n_blocks >= num_threads:
827
+ occupancy_score += 1
828
+ if total_mxn_blocks >= num_threads:
829
+ occupancy_score += 1
830
+ register_bytes = (
831
+ block_m * block_n * config.compute_dtype.itemsize
832
+ + (block_m * block_k + block_k * block_n)
833
+ * config.input_dtype.itemsize
834
+ )
835
+ matched_configs.append(
836
+ (
837
+ (isa_score, dividable_score, occupancy_score, register_bytes),
838
+ cls,
839
+ config,
840
+ )
841
+ )
842
+ if len(matched_configs) == 0:
843
+ if use_ref:
844
+ return CppMicroGemmRef(
845
+ name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
846
+ )
847
+ else:
848
+ return None
849
+ # TODO(jgong5): allow autotuning on choices of configs
850
+ return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:])
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import ctypes
3
+ import functools
4
+ import itertools
5
+ import logging
6
+ import sys
7
+ from typing import Callable, List, Optional
8
+ from unittest.mock import patch
9
+
10
+ import sympy
11
+
12
+ from .. import codecache, config, ir
13
+ from ..autotune_process import CppBenchmarkRequest, TensorMeta
14
+ from ..utils import IndentedBuffer, Placeholder, unique
15
+ from ..virtualized import V
16
+ from .common import KernelTemplate
17
+ from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
18
+
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+
23
class CppTemplate(KernelTemplate):
    """Base class for C++ (CPU) kernel templates used by the Inductor autotuner.

    Subclasses implement ``render`` to produce the C++ source; ``generate``
    drives rendering, builds a benchmark request for autotuning, and returns a
    ``CppTemplateCaller`` choice.
    """

    # Monotonically increasing id used to build unique kernel hash names.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    ) -> None:
        super().__init__(name)
        self.input_nodes = input_nodes
        # The template writes its result into this buffer.
        self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
        self.layout = layout
        self.num_threads = num_threads
        self.epilogue_creator = epilogue_creator

    def generate(self, **kwargs):
        """Render the template and wrap it in a ``CppTemplateCaller``.

        Renders under patched graph state (fake dtype lookup, flexible-layout
        indexing allowed), validates the generated call args, and builds a
        ``CppBenchmarkRequest`` so the candidate can be benchmarked.
        """
        kernel_name = f"cpp_{self.name}"
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
            kernel_name=kernel_name, num_threads=self.num_threads
        ) as kernel:
            code = kernel.render(self, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        # The leading call args must be exactly the (deduplicated) inputs
        # followed by the output buffer; anything after that is size vars.
        expected_args = [
            *unique(input_node.get_name() for input_node in self.input_nodes),
            self.output_node.get_name(),
        ]
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )
        # Cast the size hint from int to ctypes.c_ulonglong explicitly
        # since in cpp kernel, we bind it to C long
        extra_args = tuple(map(ctypes.c_ulonglong, extra_args))

        kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"

        # Create the BenchmarkRequest for CPP
        bmreq = CppBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: ir.CppTemplateBuffer,
            flag_template_buffer_has_other_users: bool,
            epilogue_nodes: Optional[List[ir.IRNode]] = None,
        ):
            # Deferred render: the kernel name is a placeholder resolved later.
            kernel = CppTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads
            )
            render = functools.partial(
                kernel.render,
                self,
                template_buffer_node=template_node,
                flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
                epilogue_nodes=epilogue_nodes,
                **kwargs,
            )
            return kernel, render

        return CppTemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
        )

    def header(self) -> IndentedBuffer:
        """Return the common C++ preamble (prefix header + extra includes)."""
        res = IndentedBuffer()
        res.writeline(codecache.cpp_prefix())
        res.splice(
            """
                #include "c10/util/Unroll.h"
            """
        )
        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
            "linux",
            "win32",
        ]
        if enable_kernel_profile:
            res.writelines(["#include <ATen/record_function.h>"])
        return res

    def render(self, **kwargs) -> str:
        # Subclasses must emit the actual template source.
        raise NotImplementedError
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import itertools
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import sympy
6
+ from sympy.parsing.sympy_parser import parse_expr
7
+
8
+ import torch
9
+ from torch.utils._sympy.symbol import SymT
10
+
11
+ from .. import config, cpp_builder, ir, lowering as L
12
+ from ..autotune_process import CppBenchmarkRequest
13
+ from ..loop_body import LoopBody
14
+ from ..select_algorithm import PartialRender
15
+ from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
16
+ from ..virtualized import V
17
+ from .common import CppWrapperKernelArgs
18
+ from .cpp import CppKernel, CppKernelProxy, KernelGroup
19
+ from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
20
+ from .cpp_wrapper_cpu import CppWrapperCpu
21
+
22
+
23
def parse_expr_with_index_symbols(expr):
    """Parse ``expr`` into a sympy expression using Inductor index symbols.

    Sympy expressions pass through untouched; lists/tuples are converted
    element-wise; anything else is stringified, parsed, and has its free
    symbols replaced by canonical integer index symbols.
    """
    if isinstance(expr, sympy.Expr):
        return expr
    if isinstance(expr, (list, tuple)):
        return [parse_expr_with_index_symbols(item) for item in expr]
    parsed = parse_expr(str(expr))
    replacements = {
        sym: sympy_index_symbol(sym.name) for sym in parsed.free_symbols
    }
    return parsed.subs(replacements)
+
33
+
34
def wrap_with_tensorbox(node) -> ir.TensorBox:
    """Wrap ``node`` in a TensorBox, using ``create`` for raw buffers."""
    if isinstance(node, ir.Buffer):
        return ir.TensorBox.create(node)
    return ir.TensorBox(node)
38
+
39
+
40
class CppTemplateKernel(CppKernel):
    """Kernel used while rendering C++ templates.

    Provides helpers the template source calls during rendering: argument
    definition, indexing/slicing of IR buffers, local-buffer management, and
    codegen of pointwise epilogues.
    """

    def __init__(self, kernel_name, num_threads):
        super().__init__(None, num_threads)
        self.kernel_name = kernel_name
        # Deferred render hooks keyed by placeholder string.
        self.render_hooks = {}
        # Kernel-local scratch buffers keyed by name.
        self.local_buffers = {}
        # C++ wrapper codegen needs the C++ flavor of kernel args.
        if isinstance(V.graph.wrapper_code, CppWrapperCpu):
            self.args = CppWrapperKernelArgs()

    def render(self, template, **kwargs):
        """Render ``template`` and resolve all registered placeholder hooks."""
        rendered = template.render(kernel=self, **kwargs)
        return PartialRender(rendered, self.render_hooks).finalize_all()

    def def_kernel(
        self,
        inputs: Dict[str, ir.Buffer],
        outputs: Dict[str, ir.Buffer],
        aliases: Optional[Dict[str, str]] = None,
    ) -> str:
        """Register kernel args and return a placeholder for the signature.

        ``aliases`` maps alias buffer names to the original buffer names they
        share storage with; aliases are dropped again before the final
        signature is emitted.
        """
        for arg_name, inp in inputs.items():
            if inp is not None:
                self.args.input_buffers[inp.get_name()] = arg_name
        for arg_name, out in outputs.items():
            self.args.output_buffers[out.get_name()] = arg_name
        if aliases is not None:
            for alias, orig in aliases.items():
                if orig in self.args.input_buffers:
                    self.args.input_buffers[alias] = self.args.input_buffers[orig]
                if orig in self.args.output_buffers:
                    self.args.output_buffers[alias] = self.args.output_buffers[orig]

        # Collect every free symbol appearing in input/output sizes/strides.
        unique_sizevars = {
            sym
            for buf in inputs.values()
            if buf is not None
            for expr in itertools.chain(buf.get_size(), buf.get_stride())
            if isinstance(expr, sympy.Expr)
            for sym in expr.free_symbols
        }
        unique_sizevars |= {
            sym
            for buf in outputs.values()
            for expr in itertools.chain(buf.get_size(), buf.get_stride())
            if isinstance(expr, sympy.Expr)
            for sym in expr.free_symbols
        }
        for sizevar in sorted(unique_sizevars, key=str):
            self.args.sizevars[sizevar] = f"k{sizevar}"

        def hook():
            # remove all aliases before generate function definition
            if aliases is not None:
                for alias in aliases:
                    if alias in self.args.input_buffers:
                        self.args.input_buffers[alias] = "REMOVED"
                    if alias in self.args.output_buffers:
                        self.args.output_buffers[alias] = "REMOVED"
            cpp_argdefs, _, _ = self.args.cpp_argdefs()
            return f"void {self.kernel_name}({', '.join(cpp_argdefs)})"

        placeholder = "<DEF_KERNEL>"
        assert placeholder not in self.render_hooks
        self.render_hooks[placeholder] = hook
        return placeholder

    def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
        """Emit the wrapper-side call to this kernel."""
        wrapper = V.graph.wrapper_code
        _, call_args, arg_types = self.args.cpp_argdefs()
        wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types)

    def dtype(self, node: ir.Buffer) -> str:
        """C++ type string for the buffer's dtype."""
        return DTYPE_TO_CPP[node.get_dtype()]

    def acc_dtype(self, node: ir.Buffer) -> str:
        """C++ accumulator type for the buffer's dtype (float for fp types)."""
        if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]:
            return "float"
        raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}")

    def size(self, node: ir.Buffer, dim: int) -> str:
        """C++ expression for the buffer's size along ``dim``."""
        return cexpr_index(self.rename_indexing(node.get_size()[dim]))

    def stride(self, node: ir.Buffer, dim: int) -> str:
        """C++ expression for the buffer's stride along ``dim``."""
        return cexpr_index(self.rename_indexing(node.get_stride()[dim]))

    def index(self, node: ir.Buffer, indices: List[Any]) -> str:
        """C++ expression indexing ``node`` at ``indices``."""
        indexer = node.layout.as_fixed().make_indexer()
        index = indexer(parse_expr_with_index_symbols(indices))
        index = self.rename_indexing(index)
        outer_name = node.get_name()
        # Local buffers keep their own name; others go through kernel args.
        if outer_name in self.local_buffers:
            inner_name = outer_name
        else:
            inner_name = self.args.input(node.get_name())
        return f"{inner_name}[{cexpr_index(index)}]"

    def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView:
        """
        Slice the given node with a list of ranges (start and end) corresponding to its dims.
        The dim is not sliced if the corresponding range is empty.
        """
        assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
        sliced = wrap_with_tensorbox(node)
        for dim, dim_range in enumerate(ranges):
            if len(dim_range) == 0:
                continue
            assert len(dim_range) == 2
            start, end = parse_expr_with_index_symbols(dim_range)
            sliced = L.slice_(sliced, dim, start, end, clamp=False)
        assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
        return sliced.data

    def view(self, node, sizes: List[Any]) -> ir.View:
        """Reshape ``node`` to ``sizes``."""
        boxed = wrap_with_tensorbox(node)
        parsed_sizes = parse_expr_with_index_symbols(sizes)
        return L.view(boxed, parsed_sizes).data

    def permute(self, node, dims):
        """Permute the dims of ``node``; result must be a ReinterpretView."""
        boxed = wrap_with_tensorbox(node)
        permuted = L.permute(boxed, dims).data
        assert isinstance(permuted, ir.ReinterpretView)
        return permuted

    def maybe_codegen_profile(self) -> str:
        """Emit a RECORD_FUNCTION line when kernel profiling is enabled."""
        if not config.cpp.enable_kernel_profile:
            return ""
        graph_id = V.graph.graph_id
        prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
        return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'

    def unroll_pragma(self, unroll):
        """Compiler-appropriate unroll pragma."""
        if cpp_builder.is_gcc():
            return f"#pragma GCC unroll {unroll}"
        return f"#pragma unroll {unroll}"

    def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str:
        """Define kernel local buffer"""
        parsed_sizes = parse_expr_with_index_symbols(sizes)
        buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, parsed_sizes))
        self.local_buffers[name] = buf
        ctype = f"{DTYPE_TO_CPP[dtype]}"
        numel = f"{cexpr_index(buf.get_numel())}"
        return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"

    def reinit_buffer_if_null(self, name):
        """Reinit the previously defined local buffer if it is null"""
        assert name in self.local_buffers
        buf = self.local_buffers[name]
        ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}"
        numel = f"{cexpr_index(buf.get_numel())}"
        return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}"

    def release_buffer(self, name):
        """Codegen the code to release the ownership of a local buffer to others"""
        assert name in self.local_buffers
        return f"_{name}.release()"

    def store_pointwise_nodes(
        self,
        dst: ir.Buffer,
        nodes: List[ir.IRNode],
        offsets: Optional[List[sympy.Expr]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ) -> str:
        """Codegen loops storing the pointwise ``nodes`` into ``dst``.

        ``offsets`` shift the iteration indices (for sub-slices); each entry
        of ``reindexers`` optionally remaps indices for the matching node.
        """
        var_sizes = (tuple(dst.get_size()), ())
        var_ranges = {
            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
            for i, sz in enumerate(var_sizes[0])
        }
        if not offsets:
            offsets = [sympy.Integer(0)] * len(var_sizes[0])
        if not reindexers:
            reindexers = [None] * len(nodes)
        assert len(offsets) == len(var_sizes[0])
        output_index = dst.get_layout().make_indexer()(var_ranges.keys())
        kernel_group = KernelGroup()
        kernel_group.args = self.args
        cpp_kernel_proxy = CppKernelProxy(kernel_group)
        bodies = []
        var_sizes_list = []
        for i, node in enumerate(nodes):
            # All intermediate nodes write to their own buffer; the last
            # node writes to dst.
            output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name()
            node = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(node, ir.Pointwise), node

            def fn(*args):
                assert len(args) == 2
                assert len(args[0]) == len(var_sizes[0])
                assert len(args[1]) == 0
                new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
                if reindexers[i] is not None:
                    new_args = reindexers[i](new_args)  # type: ignore[misc]
                V.ops.store(
                    output_name,
                    output_index,
                    node.make_loader()(new_args).value,
                )

            body = LoopBody(
                fn,
                (list(var_ranges.keys()), ()),
                var_ranges,
                list(var_ranges.keys()),
                tuple(),
            )
            bodies.append(body)
            var_sizes_list.append(var_sizes)

        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
        kernel_group.finalize_kernel(cpp_kernel_proxy, [])
        return kernel_group.loops_code.getvalue()

    def store_output(
        self,
        dst: ir.Buffer,
        src: ir.Buffer,
        orig_src: Optional[ir.Buffer] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        offsets: Optional[List[Any]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ):
        """
        Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match.
        If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues
        before stored to `dst`. The `epilogues_nodes` are all pointwise.

        Notes:
        1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute
           and stores. In case `epilogue_nodes` are not provided, we do nothing.
        2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since
           they come form the original Inductor IR, they might need to be adjusted before working with
           `src` and `dst` as outlined below:
           a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on.
              In this case, the `offsets` could be provided to adjust the indices passed to
              `epilogue_nodes` during codegen and the data ranges are also configured according to
              the sizes of `src` and `dst`.
           b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
              needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
           c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
              in `epilogue_nodes` with `src`.
        """
        assert dst.get_size() == src.get_size(), f"{dst=}, {src=}"
        if offsets:
            offsets = parse_expr_with_index_symbols(offsets)
        if epilogue_nodes:
            with LocalBufferContext(self.args) as scope:
                assert orig_src is not None
                if orig_src.get_name() != src.get_name():
                    scope.add_local_buffer(
                        src,
                        [
                            orig_src,
                        ],
                    )
                    epilogue_nodes = scope.localize_nodes(epilogue_nodes)
                return self.store_pointwise_nodes(
                    dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type]
                )
        else:
            if dst.get_name() != src.get_name():
                # src is local
                copy = L.copy(dst, src).data.data
                with LocalBufferContext(self.args) as scope:
                    scope.add_local_buffer(src)
                    return self.store_pointwise_nodes(dst, [copy])
            else:
                assert dst.layout == src.layout, f"{dst=}, {src=}"
                return ""
313
+
314
+
315
class CppTemplateCaller(ir.ChoiceCaller):
    """
    CppTemplateCaller

    This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller.
    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CppBenchmarkRequest): The benchmark request for the caller.
        template_buffer (ir.CppTemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: List[ir.Buffer],
        layout: ir.Layout,
        make_kernel_render: Callable[
            [
                ir.CppTemplateBuffer,
                bool,
                Optional[List[ir.IRNode]],
            ],
            str,
        ],
        bmreq: CppBenchmarkRequest,
        template: "CppTemplate",  # type: ignore[name-defined]  # noqa: F821
        info_kwargs: Optional[
            Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]
        ] = None,
    ):
        super().__init__(name, input_nodes, layout)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        # Delegate compilation to the benchmark request.
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        # Delegate timing to the benchmark request.
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def hash_key(self) -> str:
        # Category plus benchmark-request hash uniquely identifies the choice.
        return "-".join([self.category, self.bmreq.hash_key])

    def info_dict(
        self,
    ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]:
        return {"backend": "CPP", "op_type": "unknown"}

    def output_node(self) -> ir.TensorBox:
        # Materialize this choice as a template buffer in the graph.
        return ir.TensorBox.create(
            ir.CppTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                template=self.template,
                choice=self,
            )
        )
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import copy
4
+ import functools
5
+ import math
6
+ import sys
7
+ from collections import namedtuple
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
9
+ from unittest.mock import patch
10
+
11
+ import sympy
12
+
13
+ import torch
14
+ from torch._prims_common import is_integer_dtype
15
+ from torch.utils._sympy.symbol import symbol_is_type, SymT
16
+ from torch.utils._sympy.value_ranges import ValueRanges
17
+
18
+ from .. import ir
19
+ from ..loop_body import LoopBody
20
+ from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
21
+ from ..virtualized import ops, OpsValue, V
22
+ from .common import (
23
+ CSEVariable,
24
+ deduce_output_dtype_by_name,
25
+ ExprPrinter,
26
+ Kernel,
27
+ KernelArgs,
28
+ OptimizationContext,
29
+ )
30
+
31
+
32
+ DTYPE_TO_CPP = {
33
+ torch.float32: "float",
34
+ torch.float64: "double",
35
+ torch.float16: "half",
36
+ torch.int64: "int64_t",
37
+ torch.int32: "int32_t",
38
+ torch.int16: "int16_t",
39
+ torch.int8: "int8_t",
40
+ torch.uint64: "uint64_t",
41
+ torch.uint32: "uint32_t",
42
+ torch.uint16: "uint16_t",
43
+ torch.uint8: "uint8_t",
44
+ torch.bool: "bool",
45
+ torch.bfloat16: "bfloat16",
46
+ torch.complex64: "c10::complex<float>",
47
+ torch.float8_e4m3fn: "float8_e4m3fn",
48
+ torch.float8_e5m2: "float8_e5m2",
49
+ }
50
+
51
+ DTYPE_TO_ATEN = {
52
+ torch.float32: "at::kFloat",
53
+ torch.float64: "at::kDouble",
54
+ torch.float16: "at::kHalf",
55
+ torch.int64: "at::kLong",
56
+ torch.int32: "at::kInt",
57
+ torch.int16: "at::kShort",
58
+ torch.int8: "at::kChar",
59
+ torch.uint64: "at::kUInt64",
60
+ torch.uint32: "at::kUInt32",
61
+ torch.uint16: "at::kUInt16",
62
+ torch.uint8: "at::kByte",
63
+ torch.uint32: "at::kUInt32",
64
+ torch.uint64: "at::kUInt64",
65
+ torch.bool: "at::kBool",
66
+ torch.bfloat16: "at::kBFloat16",
67
+ torch.complex32: "at::kComplexHalf",
68
+ torch.complex64: "at::kComplexFloat",
69
+ torch.complex128: "at::kComplexDouble",
70
+ torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
71
+ torch.float8_e5m2: "at::kFloat8_e5m2",
72
+ torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
73
+ torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
74
+ }
75
+
76
+ DEVICE_TO_ATEN = {
77
+ "cpu": "at::kCPU",
78
+ "cuda": "at::kCUDA",
79
+ }
80
+
81
+ LAYOUT_TO_ATEN = {
82
+ torch.strided: "at::kStrided",
83
+ torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
84
+ }
85
+
86
+ _IS_WINDOWS = sys.platform == "win32"
87
+
88
+ INDEX_TYPE = "int64_t"
89
+
90
+ GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
91
+
92
+
93
+ def get_promote_dtype(args):
94
+ return (
95
+ functools.reduce(
96
+ torch.promote_types, # type: ignore[arg-type]
97
+ [n.dtype for n in args if isinstance(n, CppCSEVariable)],
98
+ )
99
+ if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable))
100
+ else None # not enough info to calculate the promote dtype
101
+ )
102
+
103
+
104
+ def promote_args(new_args):
105
+ def promote_arg(arg, promote_type):
106
+ if (
107
+ isinstance(arg, CppCSEVariable)
108
+ and arg.dtype
109
+ and promote_type
110
+ and arg.dtype != promote_type
111
+ ):
112
+ arg = ops.to_dtype(arg, promote_type)
113
+ arg = arg.value if isinstance(arg, OpsValue) else arg
114
+ arg.dtype = promote_type
115
+ return arg
116
+
117
+ promote_type = get_promote_dtype(new_args)
118
+ promote_fn = functools.partial(
119
+ promote_arg,
120
+ promote_type=promote_type,
121
+ )
122
+ if (
123
+ all(
124
+ new_arg.dtype is not None
125
+ for new_arg in new_args
126
+ if isinstance(new_arg, CppCSEVariable)
127
+ )
128
+ and promote_type
129
+ ):
130
+ new_args = list(map(promote_fn, new_args))
131
+ return new_args
132
+
133
+
134
+ def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext:
135
+ return node.meta.get(OptimizationContext.key, None)
136
+
137
+
138
+ def get_current_node_opt_ctx() -> OptimizationContext:
139
+ assert V.interpreter.current_node
140
+ return get_opt_ctx(V.interpreter.current_node)
141
+
142
+
143
+ def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs):
144
+ if (
145
+ output_dtype := deduce_output_dtype_by_name(
146
+ name,
147
+ *args,
148
+ **kwargs,
149
+ )
150
+ ) is not None:
151
+ return output_dtype
152
+ elif name == "masked":
153
+ # <TODO> Leslie: perhaps we can also deduce the masked dtype by
154
+ # inputs' CppCseVariable like other. Let's check it if any
155
+ # unexpected failures.
156
+ assert (
157
+ hasattr(V.interpreter, "current_node")
158
+ and V.interpreter.current_node.target.startswith("masked_subblock")
159
+ and get_current_node_opt_ctx() is not None
160
+ )
161
+ return get_current_node_opt_ctx().dtype
162
+ else:
163
+ # deduce output dtype by inputs' dtype
164
+ assert all(
165
+ arg.dtype is not None for arg in args if isinstance(arg, CppCSEVariable)
166
+ )
167
+ return functools.reduce(
168
+ torch.promote_types, # type: ignore[arg-type]
169
+ [arg.dtype for arg in args if isinstance(arg, CppCSEVariable)],
170
+ )
171
+
172
+
173
class CppCSEVariable(CSEVariable):
    """CSE variable specialized for the C++ backend.

    Tracks whether the variable is vectorized, its deduced dtype, and the
    loop itervars its value depends on.
    """

    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
        super().__init__(name, bounds)
        self.is_vec = False
        self.dtype: Optional[torch.dtype] = None
        self.dependent_itervars: Set[sympy.Symbol] = set()

    def __repr__(self) -> str:
        return (
            f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
            f"dependent_itervars: {self.dependent_itervars})"
        )

    def update_on_args(self, name, args, kwargs):
        """Propagate itervar deps, vectorization flag, and dtype from args."""
        if name == "load":
            # args[2] is index
            self._set_dependent_itervars(args[2])
        else:
            # propagate relevant itervars and is_vec from args
            self.dependent_itervars.update(
                *[
                    arg.dependent_itervars
                    for arg in args
                    if isinstance(arg, CppCSEVariable)
                ]
            )
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
                self.is_vec = True
        # NOTE [Deduce dtype of CppCSEVariable at runtime]
        self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
        assert self.dtype is not None

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Set the relevant itervars for this variable based on the `index` expression.
        This includes the itervars directly used in the `index` as well as relevant itervars
        of other cse variables used in the `index`.
        """
        for sym in index.free_symbols:
            if sym in V.kernel.itervars:
                self.dependent_itervars.add(sym)  # type: ignore[arg-type]
            elif sym.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
                self.dependent_itervars.update(
                    V.kernel.cse.varname_map[sym.name].dependent_itervars  # type: ignore[attr-defined]
                )

    def depends_on(self, itervar: sympy.Symbol):
        """True when this variable's value depends on ``itervar``."""
        return itervar in self.dependent_itervars
223
+
224
+
225
+ class CppPrinter(ExprPrinter):
226
+ def _print_Integer(self, expr):
227
+ return (
228
+ f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
229
+ )
230
+
231
+ def _print_Where(self, expr):
232
+ c = self.paren(self.doprint(expr.args[0]))
233
+ p = self.paren(self.doprint(expr.args[1]))
234
+ q = self.paren(self.doprint(expr.args[2]))
235
+ return f"{c} ? {p} : {q}"
236
+
237
+ def _print_ModularIndexing(self, expr):
238
+ x, div, mod = expr.args
239
+ x = self.paren(self.doprint(x))
240
+ if div != 1:
241
+ div = self.paren(self.doprint(div))
242
+ if expr.is_integer:
243
+ x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
244
+ else:
245
+ x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
246
+ mod = self.paren(self.doprint(mod))
247
+ return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
248
+
249
+ def _print_FloorDiv(self, expr):
250
+ x, div = expr.args
251
+ x = self.paren(self.doprint(x))
252
+ div = self.paren(self.doprint(div))
253
+ if expr.is_integer:
254
+ return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
255
+ return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
256
+
257
+ def _print_floor(self, expr):
258
+ assert len(expr.args) == 1
259
+ r = f"std::floor({self._print(expr.args[0])})"
260
+ return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
261
+
262
+ def _print_FloorToInt(self, expr):
263
+ assert len(expr.args) == 1
264
+ r = f"std::floor({self._print(expr.args[0])})"
265
+ return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
266
+
267
+ def _print_TruncToInt(self, expr):
268
+ assert len(expr.args) == 1
269
+ r = f"std::trunc({self._print(expr.args[0])})"
270
+ return f"static_cast<{INDEX_TYPE}>({r})"
271
+
272
+ def _print_TruncToFloat(self, expr):
273
+ assert len(expr.args) == 1
274
+ return f"std::trunc({self._print(expr.args[0])})"
275
+
276
+ def _print_ToFloat(self, expr):
277
+ assert len(expr.args) == 1
278
+ return f"static_cast<double>({self._print(expr.args[0])})"
279
+
280
+ # TODO: This is wrong if one of the inputs is negative. This is hard to
281
+ # tickle though, as the inputs are typically positive (and if we can prove
282
+ # they are positive, we will have used Mod instead, for which this codegen
283
+ # is right).
284
+ def _print_PythonMod(self, expr):
285
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
286
+
287
+ def _print_CMod(self, expr):
288
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
289
+
290
+ def _print_IntTrueDiv(self, expr):
291
+ lhs, rhs = expr.args
292
+ # TODO: This is only accurate up to 2**53
293
+ return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
294
+
295
+ # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
296
+ # use std::pow, that operates on floats
297
+ def _print_PowByNatural(self, expr):
298
+ raise NotImplementedError(
299
+ f"_print_PowByNatural not implemented for {type(self)}"
300
+ )
301
+
302
+ def _print_FloatTrueDiv(self, expr):
303
+ lhs, rhs = expr.args
304
+ return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
305
+
306
+ def _print_FloatPow(self, expr):
307
+ base, exp = expr.args
308
+ return f"std::pow({self._print(base)}, {self._print(exp)})"
309
+
310
    def _print_Pow(self, expr):
        """Emit C++ for a sympy Pow.

        Special-cases +/-0.5 exponents (sqrt), expands small integer
        exponents into repeated multiplication, and falls back to
        std::pow for everything else.
        """
        # Uses float constants to perform FP div
        base, exp = expr.args
        base = self._print(base)

        # sqrt / reciprocal sqrt shortcut.
        if exp == 0.5 or exp == -0.5:
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                # Expand x**n as x*x*...*x; avoids std::pow on floats.
                r = "*".join([self.paren(base)] * exp)
            elif exp < 0:
                # Negative integer exponent: reciprocal of the expansion.
                r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
            else:  # exp == 0
                r = "1.0"

            # Cast back to the index type when the whole expression is integral.
            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # TODO: float vs double
            return f"std::pow({base}, {float(exp)})"
330
+
331
+ def _print_Rational(self, expr):
332
+ # Uses float constants to perform FP div
333
+ if expr.q == 1:
334
+ r = f"{expr.p}"
335
+ else:
336
+ r = f"{expr.p}.0/{expr.q}.0"
337
+ return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
338
+
339
+ def _print_ceiling(self, expr):
340
+ assert len(expr.args) == 1
341
+ r = f"std::ceil({self._print(expr.args[0])})"
342
+ return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
343
+
344
+ def _print_CeilToInt(self, expr):
345
+ assert len(expr.args) == 1
346
+ r = f"std::ceil({self._print(expr.args[0])})"
347
+ return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
348
+
349
+ def _print_Min(self, expr):
350
+ args = [self._print(a) for a in expr.args]
351
+ if len(args) == 2:
352
+ return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
353
+ else:
354
+ # Initializer list overload
355
+ il = "{" + ", ".join(args) + "}"
356
+ return f"std::min({il})"
357
+
358
+ def _print_Max(self, expr):
359
+ args = [self._print(a) for a in expr.args]
360
+ if len(args) == 2:
361
+ return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
362
+ else:
363
+ # Initializer list overload
364
+ il = "{" + ", ".join(args) + "}"
365
+ return f"std::max({il})"
366
+
367
+ def _print_Abs(self, expr):
368
+ assert len(expr.args) == 1
369
+ return f"std::abs({self._print(expr.args[0])})"
370
+
371
+ def _print_OpaqueUnaryFn_cos(self, expr):
372
+ assert len(expr.args) == 1
373
+ return f"std::cos({self._print(expr.args[0])})"
374
+
375
+ def _print_OpaqueUnaryFn_cosh(self, expr):
376
+ assert len(expr.args) == 1
377
+ return f"std::cosh({self._print(expr.args[0])})"
378
+
379
+ def _print_OpaqueUnaryFn_acos(self, expr):
380
+ assert len(expr.args) == 1
381
+ return f"std::acos({self._print(expr.args[0])})"
382
+
383
+ def _print_OpaqueUnaryFn_sin(self, expr):
384
+ assert len(expr.args) == 1
385
+ return f"std::sin({self._print(expr.args[0])})"
386
+
387
+ def _print_OpaqueUnaryFn_sinh(self, expr):
388
+ assert len(expr.args) == 1
389
+ return f"std::sinh({self._print(expr.args[0])})"
390
+
391
+ def _print_OpaqueUnaryFn_asin(self, expr):
392
+ assert len(expr.args) == 1
393
+ return f"std::asin({self._print(expr.args[0])})"
394
+
395
+ def _print_OpaqueUnaryFn_tan(self, expr):
396
+ assert len(expr.args) == 1
397
+ return f"std::tan({self._print(expr.args[0])})"
398
+
399
+ def _print_OpaqueUnaryFn_tanh(self, expr):
400
+ assert len(expr.args) == 1
401
+ return f"std::tanh({self._print(expr.args[0])})"
402
+
403
+ def _print_OpaqueUnaryFn_atan(self, expr):
404
+ assert len(expr.args) == 1
405
+ return f"std::atan({self._print(expr.args[0])})"
406
+
407
+ def _print_OpaqueUnaryFn_sqrt(self, expr):
408
+ return f"std::sqrt({self._print(expr.args[0])})"
409
+
410
    def _print_RoundToInt(self, expr):
        """Emit std::lrint for round-half-to-even to integer."""
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        """Emit C++ for rounding to *ndigits* decimal places.

        Scales by 1e{ndigits}, rounds with std::nearbyint, then scales back.
        Raises ValueError for integer inputs (only reachable with negative
        ndigits, which is unsupported here).
        """
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr):
        """Emit the C++ literal for sympy's True."""
        return "true"

    def _print_BooleanFalse(self, expr):
        """Emit the C++ literal for sympy's False."""
        return "false"
431
+
432
+
433
# A function to print, useful for printing sympy symbols.
cexpr = CppPrinter().doprint


def cexpr_index(index):
    """Print a sympy index expression cast to the C++ index type."""
    return f"static_cast<{INDEX_TYPE}>({cexpr(index)})"
439
+
440
+
441
def value_to_cpp(value, cpp_type):
    """Render a Python scalar as a C++ expression of type *cpp_type*.

    Handles +/-infinity, bool, and NaN specially; everything else is
    emitted as ``static_cast<cpp_type>(repr(value))``.
    """
    inf = float("inf")
    if value == -inf:
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    if value == inf:
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    # bool must be handled before the NaN check so True/False lower to
    # the C++ keywords rather than repr() output.
    if isinstance(value, bool):
        literal = "true" if value else "false"
        return f"static_cast<{cpp_type}>({literal})"
    if math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    return f"static_cast<{cpp_type}>({repr(value)})"
452
+
453
+
454
def rewrite_index_for_function(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """Rewrite *index* (expressed over the global buffer's loop variables) so it
    addresses the smaller local buffer instead.

    Index symbols that do not correspond to the local buffer's innermost
    dimensions are substituted with zero.
    """
    # Local buffer at the inner dimensions
    snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    scheduler_nodes = snode.get_nodes()
    # Pick the node with the richest loop structure (reduction nodes win) to
    # recover the full iteration ranges.
    _, (group, reduction_group) = max(
        scheduler_nodes, key=lambda x: int(x.is_reduction())
    ).group
    call_ranges = tuple(group) + tuple(reduction_group)
    # The local buffer covers only the innermost dims, so keep the last
    # len(local size) loop variables "x{...}".
    indices_to_keep = [
        f"x{len(call_ranges) - (idx + 1)}"
        for idx in range(len(local_buf.get_layout().size))
    ]
    # Sort for deterministic substitution order.
    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
    replacements = {}
    for x in sorted_symbols:
        if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined]
            # Only keep index used by local buffer
            replacements[x] = sympy.core.numbers.Zero()
    index = sympy_subs(index, replacements)  # type: ignore[arg-type]
    return index
479
+
480
+
481
def rewrite_index_for_nodes(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """Rewrite *index* to address the local buffer by re-deriving it from the
    local buffer's own layout/indexer, zeroing dimensions whose index symbol
    is unused in the original expression.
    """
    used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
    index_vars = []
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    for i in range(len(local_buf.get_size())):
        var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
        # Dimensions not referenced by the original index collapse to 0.
        index_vars.append(var if var in used_vars else 0)
    index = local_buf.layout.make_indexer()(index_vars)
    return index
494
+
495
+
496
class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
    """Ops-handler wrapper that redirects loads/stores on registered global
    buffers to their function-local replacements, rewriting indices via the
    supplied *rewrite_index* strategy.
    """

    def __init__(
        self,
        inner,
        global_to_local: Dict[str, ir.Buffer],
        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
    ) -> None:
        super().__init__(inner)
        # Maps global buffer name -> local ir.Buffer replacement.
        self.global_to_local = global_to_local
        # Strategy used to translate a global index into a local one.
        self.rewrite_index = rewrite_index

    def localize(self, name: str, index: sympy.Expr):
        """Return (name, index) redirected to the local buffer when *name*
        is registered; otherwise pass both through unchanged."""
        if self.global_to_local and name in self.global_to_local:
            assert self.rewrite_index is not None
            index = self.rewrite_index(self, index, name)
            name = self.global_to_local[name].get_name()
        return name, index

    def load(self, name: str, index: sympy.Expr):
        return self._inner.load(*self.localize(name, index))

    def store(self, name, index, value, mode=None):
        local_buffer_name, local_buffer_index = self.localize(name, index)
        res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
        if (
            self.global_to_local
            and name in self.global_to_local
            and isinstance(V.kernel, Kernel)
        ):
            # Remove name of local buffer from Kernel.store_buffer_names
            # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
            V.kernel.store_buffer_names.discard(local_buffer_name)
        return res

    def store_reduction(self, name, index, value):
        return self._inner.store_reduction(*self.localize(name, index), value)
532
+
533
+
534
class LocalBufferContext:
    """
    This class creates a context that helps to generate code involving Inductor IR with
    function local buffers. These buffers are constructed during the codegen process and
    are used to store intermediate results such as local accumulators. We do not want to
    add them to `V.graph` since they are not global and we do not want to add them as
    function arguments either. So we patch the codegen processes under this scope to support
    these buffers without exposure to the outside world.
    """

    def __init__(self, kernel_args: KernelArgs) -> None:
        self.kernel_args = kernel_args
        # ExitStack that unwinds all the patches installed in __enter__.
        self.exit_stack = contextlib.ExitStack()
        # map local buffer name to local buffer
        self.local_buffers: Dict[str, ir.Buffer] = {}
        # map global buffer name to global buffer
        self.global_buffers: Dict[str, ir.Buffer] = {}
        # map global buffer name to local buffer
        self.global_to_local: Dict[str, ir.Buffer] = {}

    def __enter__(self):
        self.exit_stack.__enter__()
        original_get_dtype = V.graph.get_dtype

        # Patch dtype lookup so local buffers (unknown to V.graph) resolve.
        def get_dtype(name):
            if name in self.local_buffers:
                return self.local_buffers[name].get_dtype()
            return original_get_dtype(name)

        self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))

        original_input = self.kernel_args.input

        # Local buffers are not kernel arguments: return the name verbatim.
        def input(name):
            if name in self.local_buffers:
                return name
            return original_input(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))

        original_output = self.kernel_args.output

        def output(name):
            if name in self.local_buffers:
                return name
            return original_output(name)

        self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))

        # Set current LocalBufferContext into V
        self.exit_stack.enter_context(V.set_local_buffer_context(self))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(
        self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None
    ):
        """Register *local_buffer*, optionally replacing *global_buffers*.

        Replaced global buffers are marked removed on V.graph so they are not
        allocated globally.
        """
        assert local_buffer.get_name() not in self.local_buffers
        self.local_buffers[local_buffer.get_name()] = local_buffer
        if global_buffers:
            for global_buffer in global_buffers:
                global_buffer_name = global_buffer.get_name()
                assert (
                    global_buffer_name not in self.global_buffers
                    and global_buffer_name not in self.global_to_local
                )
                self.global_buffers[global_buffer_name] = global_buffer
                self.global_to_local[global_buffer_name] = local_buffer
                V.graph.removed_buffers.add(global_buffer_name)

    def localize_function(
        self,
        fn: Callable[..., Any],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_function,
    ):
        """Wrap *fn* so that, while it runs, loads/stores on registered global
        buffers are redirected to their local replacements."""

        def inner(*args, **kwargs):
            with V.set_ops_handler(
                LocalizeBufferHandler(
                    V.get_ops_handler(),
                    global_to_local=self.global_to_local,
                    rewrite_index=rewrite_index,
                )
            ):
                return fn(*args, **kwargs)

        return inner

    def localize_nodes(
        self,
        nodes: List[ir.IRNode],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_nodes,
    ) -> List[ir.IRNode]:
        """
        Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
        though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
        for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
        instead of `global_buf`, i.e., all the loads and stores are redirected to
        `local_buf`. This helps the fused loops to work on smaller-sized local buffers
        for better data locality.

        The data access of `local_buf` is assumed to be contiguous with the
        same order as the `global_buf`.
        """
        assert len(nodes) > 0

        def wrap_inner_fn_for_node(node: ir.IRNode):
            # Shallow-copy the Loops so the original node is left untouched.
            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
            assert isinstance(loops, ir.Loops)
            new_loops = copy.copy(loops)
            if isinstance(node, ir.ComputedBuffer):
                new_node = ir.ComputedBuffer(
                    node.get_name(), node.get_layout(), new_loops
                )
            else:
                new_node = new_loops  # type: ignore[assignment]

            new_loops.inner_fn = self.localize_function(
                new_loops.inner_fn,
                rewrite_index,
            )
            return new_node

        return [wrap_inner_fn_for_node(node) for node in nodes]
665
+
666
+
667
def unify_mask_base_type(
    buffer: IndentedBuffer,
    vars: Tuple[CSEVariable, ...],
    dtype=torch.float,
):
    """
    Given a list of cse variables, cast each to the mask base *dtype* and
    return the casted cse variables.

    Note: returns a lazy generator — the casts are emitted into *buffer*
    only as the result is iterated.
    """
    new_vars = (
        V.kernel.cse.generate(
            buffer,
            f"{V.kernel._get_mask_cast(var, dtype)}",
        )
        for var in vars
    )
    return new_vars
684
+
685
+
686
def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
    """Emit a C++ immediately-invoked lambda that evaluates *rand_function*
    per lane of the vectorized *offset* and reloads the results as a
    Vectorized/VectorizedN value of *dst_dtype*.

    Returns *code* with the lambda appended (the trailing "()" invokes it).
    """
    assert is_integer_dtype(offset.dtype)
    code.writeline("[&]()")
    with code.indent():
        # Spill the vector offsets to a scalar array so the (scalar) RNG can
        # be called element by element.
        code.writeline(
            f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];"
        )
        code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];")
        code.writeline(f"{offset}.store(offset);")
        code.writeline(
            f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )"
        )
        with code.indent():
            code.writeline(rand_function)
        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
        if num_vectors == 1:
            code.writeline(
                f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);"
            )
        else:
            code.writeline(
                f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);"
            )
    code.writeline("()")
    return code
711
+
712
+
713
def get_gemm_template_output_and_compute_dtype(input_dtype):
    """Map a GEMM template input dtype to its (output, compute) dtypes.

    uint8 inputs accumulate in int32; every other input dtype computes
    and outputs in float32.
    """
    if input_dtype == torch.uint8:
        return (torch.int32, torch.int32)
    return (torch.float32, torch.float32)
718
+
719
+
720
def create_epilogue_with_attr(input_buffer, attr, **kwargs):
    """Build an ir.Pointwise epilogue applying the op named by *attr* to
    *input_buffer*.

    Supported attrs: relu, gelu (algorithm="none"/"tanh"), swish, sigmoid,
    tanh, hardswish, hardsigmoid, leaky_relu, hardtanh, add/sub/mul
    (with "other"), and bias_add (with "other", "beta", "dtype").
    Raises ValueError for anything else. Most branches upcast to float for
    the math and cast back to the input dtype.
    """
    input_loader = input_buffer.make_loader()
    dtype = input_buffer.get_dtype()
    if attr == "relu":

        def inner_fn(index):
            input = input_loader(index)
            zero = ops.constant(0, dtype)
            return ops.maximum(input, zero)

    elif attr == "gelu":
        assert "algorithm" in kwargs
        if kwargs["algorithm"] == "none":
            # Exact (erf-based) GELU.
            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                # 1/sqrt(2)
                const = ops.constant(0.7071067811865476, torch.float)
                result = input * half * (ops.erf(input * const) + one)
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

        else:
            assert kwargs["algorithm"] == "tanh"

            # Tanh-approximation GELU.
            def inner_fn(index):
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                # sqrt(2/pi)
                const1 = ops.constant(0.7978845608028654, torch.float)
                const2 = ops.constant(0.044715, torch.float)
                result = (
                    half
                    * input
                    * (
                        one
                        + ops.tanh(const1 * (input + const2 * input * input * input))
                    )
                )
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

    elif attr == "swish":

        def inner_fn(index):
            input = input_loader(index)
            result = input * ops.sigmoid(input)
            return result

    elif attr == "sigmoid":

        def inner_fn(index):
            return ops.sigmoid(input_loader(index))

    elif attr == "tanh":

        def inner_fn(index):
            return ops.tanh(input_loader(index))

    elif attr == "hardswish" or attr == "hardsigmoid":

        # hardsigmoid(x) = clamp(x + 3, 0, 6) / 6; hardswish multiplies by x.
        def hardsigmoid_float(input):
            zero = ops.constant(0, torch.float)
            six = ops.constant(6, torch.float)
            three = ops.constant(3, torch.float)
            one_over_six = ops.constant(0.16666666666666666, torch.float)
            max = ops.maximum(input + three, zero)
            min = ops.minimum(max, six)
            return min * one_over_six

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = hardsigmoid_float(input)
            if attr == "hardswish":
                result = input * result
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "leaky_relu":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 1
        negative_slope = kwargs["scalars"][0]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            zero = ops.constant(0, torch.float)
            result = ops.where(
                input > zero, input, input * ops.constant(negative_slope, torch.float)
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "hardtanh":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 2
        min_value = kwargs["scalars"][0]
        max_value = kwargs["scalars"][1]

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = ops.minimum(
                ops.maximum(input, ops.constant(min_value, torch.float)),
                ops.constant(max_value, torch.float),
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr in ["add", "sub", "mul"]:
        assert "other" in kwargs
        other = kwargs["other"]
        num_input_dims = len(input_buffer.get_size())
        num_other_dims = len(other.get_size())
        # If "other" has fewer dims, index it with the trailing dims only
        # (right-aligned broadcast).
        dims_diff = num_input_dims - num_other_dims
        other_loader = other.make_loader()

        def inner_fn(index):
            op = getattr(ops, attr)
            if dims_diff != 0:
                return op(input_loader(index), other_loader(index[dims_diff:]))
            else:
                return op(input_loader(index), other_loader(index))

    elif attr == "bias_add":
        assert "other" in kwargs
        assert "beta" in kwargs
        assert "dtype" in kwargs
        beta = kwargs["beta"]
        other = kwargs["other"]
        # bias_add overrides the output dtype with the caller-provided one.
        dtype = kwargs["dtype"]
        bias_loader = other.make_loader()

        def inner_fn(index):
            bias = bias_loader(index)
            input = input_loader(index)
            if beta != 1:
                result = ops.constant(beta, torch.float) * bias + input
            else:
                result = bias + input
            return result

    else:
        raise ValueError(f"Unsupported epilogue attribute: {attr}")
    return ir.Pointwise(
        device=input_buffer.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input_buffer.get_size(),
    )
884
+
885
+
886
def _get_loop_body(fn_list):
    """Extract the LoopBody objects from *fn_list*.

    Each entry is either a LoopBody already, a functools.partial whose first
    positional arg carries `._body`, or a localize_function wrapper exposing
    the partial via `.original_fn`.
    """
    if all(isinstance(fn, LoopBody) for fn in fn_list):
        loop_bodies = fn_list
    else:
        if hasattr(fn_list[0], "original_fn"):
            # For the case of local buffer, we wrap the fn with localize_function
            assert all(hasattr(fn, "original_fn") for fn in fn_list)
            assert all(
                isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
            )
            loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
        else:
            assert all(isinstance(fn, functools.partial) for fn in fn_list)
            assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
            loop_bodies = [fn.args[0]._body for fn in fn_list]
    assert loop_bodies is not None
    return loop_bodies
903
+
904
+
905
def _get_dtype_from_loopbodies(loop_bodies):
    """Collect the set of dtypes recorded on call_method nodes across all
    graphs (root and subblocks) of the given loop bodies."""
    dtypes = set()
    for body in loop_bodies:
        subgraphs = [body.root_block.graph]
        subgraphs.extend(sub.graph for sub in body.subblocks.values())
        for graph in subgraphs:
            for node in graph.nodes:
                if node.op == "call_method":
                    dtypes.add(node.meta[OptimizationContext.key].dtype)
    return dtypes
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import os
4
+ from itertools import chain, count
5
+ from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union
6
+
7
+ import sympy
8
+
9
+ from torch import dtype as torch_dtype
10
+ from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
11
+ from torch._inductor.runtime.triton_heuristics import grid as default_grid
12
+
13
+ from .. import config
14
+ from ..codecache import CudaKernelParamCache
15
+ from ..utils import DeferredLineBase
16
+ from ..virtualized import V
17
+ from .aoti_hipify_utils import maybe_hipify_code_wrapper
18
+ from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
19
+ from .cpp_utils import cexpr, DTYPE_TO_CPP
20
+ from .cpp_wrapper_cpu import CppWrapperCpu
21
+ from .wrapper import SymbolicCallArg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..graph import GraphLowering
26
+
27
+
28
class DeferredCudaKernelLine(DeferredLineBase):
    """
    When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels
    to be tuned and stored as cubin files, so use a deferred line to backfill that information
    """

    def __init__(
        self,
        kernel_name: str,
        line_template: str,
        keys: Tuple[str, ...],
    ):
        super().__init__(line_template)
        assert not isinstance(line_template, DeferredLineBase)
        self.kernel_name = kernel_name
        # %-style template; *keys* name the CudaKernelParamCache entries
        # that fill its placeholders, in order.
        self.line_template = line_template
        self.keys = keys

    def __call__(self):
        """Materialize the line from the (now-populated) CudaKernelParamCache."""
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        for key in self.keys:
            assert (
                key in params
            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
            if key == get_cpp_wrapper_cubin_path_name():
                # The cubin file must already exist on disk at codegen time.
                assert os.path.exists(params[key]), f"{params[key]} does not exist"

        return self.line_template % tuple(params[key] for key in self.keys)

    def _new_line(self, line):
        return DeferredCudaKernelLine(self.kernel_name, line, self.keys)
62
+
63
+
64
class DeferredCudaDefaultGrid:
    """
    A container for the default grid, which may be used by DeferredCudaGridLine
    """

    def __init__(
        self,
        kernel_name: str,
        grid,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        self.kernel_name = kernel_name
        self.grid = grid
        # Defaults to triton_heuristics.grid when not provided.
        self.grid_callable = grid_callable
        self.grid_extra_kwargs = grid_extra_kwargs

    def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]):
        """Recursively unwrap SymbolicCallArg entries to their inner sympy exprs."""
        if isinstance(grid, (list, tuple)):
            return [self._process_grid(e) for e in grid]
        else:
            return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid

    def __call__(self):
        """Resolve the grid using the autotuned block sizes from the cache."""
        grid = self.grid
        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
        grid = self._process_grid(grid)
        grid_callable = self.grid_callable or default_grid
        if not self.grid_extra_kwargs:
            grid_fn = grid_callable(*grid)
        else:
            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)

        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        block_cfg = {
            "XBLOCK": params["x_block"],
            "YBLOCK": params["y_block"],
            "ZBLOCK": params["z_block"],
        }
        return grid_fn(block_cfg)
107
+
108
+
109
class DeferredCudaGridLine(DeferredLineBase):
    """
    When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels
    to be tuned and stored as cubin files, so use a deferred line to backfill that information
    """

    def __init__(
        self,
        kernel_name: str,
        grid_var: str,
        grid,
        autotune_configs,
    ):
        super().__init__("")
        self.kernel_name = kernel_name
        # Name of the C++ Grid variable this line declares.
        self.grid_var = grid_var
        self.grid = grid
        # Non-None for user-defined Triton kernels: one grid per config.
        self.autotune_configs = autotune_configs

    def __call__(self):
        """Emit the `Grid <var> = Grid(...);` declaration once tuning is done."""
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"

        if self.autotune_configs is not None:
            # This indicates the Triton kernel is a user-defined one.
            grid = None
            if len(self.grid) == 1:
                grid = self.grid[0]
            else:
                # Pick the grid matching the winning autotune config's kwargs.
                for i, c in enumerate(self.autotune_configs):
                    if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
                        grid = self.grid[i]
                        break
            assert grid is not None
        elif isinstance(self.grid, DeferredCudaDefaultGrid):
            grid = self.grid()
        else:
            grid = self.grid

        assert len(grid) != 0, "Grid can't be empty"
        grid_args_str = ", ".join(
            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
        )
        return f"    Grid {self.grid_var} = Grid({grid_args_str});"

    def _new_line(self, line):
        # The line text is regenerated from the stored parameters, so the
        # incoming `line` is intentionally ignored.
        return DeferredCudaGridLine(
            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
        )
160
+
161
+
162
+ class CppWrapperCuda(CppWrapperCpu):
163
+ """
164
+ Generates cpp wrapper for running on GPU and calls CUDA kernels
165
+ """
166
+
167
    def __init__(self) -> None:
        # NOTE(review): self.device is set before super().__init__() —
        # presumably the base constructor reads it; confirm against CppWrapperCpu.
        self.device = "cuda"
        super().__init__()
        # Monotonic counter used to name generated grid variables.
        self.grid_id = count()
        self.cuda = True
172
+
173
    def write_header(self):
        """Write the C++ header, adding CUDA-specific includes on top of the
        CPU wrapper's header."""
        if V.graph.is_const_graph:
            # We do not write header for constant graph, it will be written by main module.
            return

        super().write_header()

        self.header.splice("#include <filesystem>")
        if config.abi_compatible:
            self.header.splice(
                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
            )
        else:
            # Non-ABI mode inlines the raw CUDA driver helpers.
            self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
        self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))
188
+
189
    def write_get_raw_stream(self, index, graph=None):
        """Declare a cudaStream_t for device *index* and return its C++ name.

        *graph* is accepted for interface compatibility but unused here.
        """
        name = f"stream{index}"
        self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
        self.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
        )
        return name
196
+
197
    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
    ):
        """Define a kernel. CUDA kernels are intentionally not defined here
        (they are loaded from cubin files at runtime, so this returns None);
        cpp kernels are delegated to the CPU wrapper."""
        if not cuda:
            return super().define_kernel(name, kernel, metadata, cuda)
202
+
203
    def generate(self, is_inference):
        """Generate the wrapper; in JIT (non-AOT) mode, first declare a static
        CUfunction handle per kernel so loads can be cached across calls."""
        self.prefix.writeline("\n")
        if not V.graph.aot_mode:
            for kernel in chain(
                sorted(self.src_to_kernel.values()),
                sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
            ):
                self.prefix.writeline(
                    maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
                )
            self.prefix.writeline("\n")
        return super().generate(is_inference)
215
+
216
    def generate_user_defined_triton_kernel(
        self,
        kernel_name: str,
        raw_args: List[Any],
        grid: List[Any],
        configs,
        triton_meta,
        constexprs,
    ):
        """Emit the call for a user-defined Triton kernel, dropping constexpr
        arguments before dispatching to generate_kernel_call."""
        # in C++ wrapper, we don't pass constexpr args, as they don't
        # get added as parameters to the PTX code compiled from the
        # user-defined Triton kernel (only non-constexpr args do)
        raw_args = [
            raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
        ]
        args = [self.val_to_arg_str(v) for v in raw_args]
        # Tensor-like args report their dtype; plain Python values use type().
        arg_types = [
            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
            for arg in raw_args
        ]
        self.generate_kernel_call(
            kernel_name,
            args,
            arg_types=arg_types,
            raw_args=raw_args,
            grid=grid,
            cuda=True,
            triton=True,
            triton_meta=triton_meta,
            autotune_configs=configs,
        )
247
+
248
    @functools.lru_cache(None)  # noqa: B019
    def generate_load_kernel_once(
        self,
        kernel_name: str,
        graph: "GraphLowering",  # for per-graph caching
    ):
        """Emit (at most once per kernel per graph) the guarded loadKernel call
        and return the C++ variable name holding the kernel handle.

        NOTE: lru_cache on a method keys on self/graph and keeps them alive
        for the cache's lifetime — accepted here (see noqa: B019).
        """
        keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem")
        # AOT mode stores kernel handles on the model object instead of statics.
        kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
        self.writeline(f"if ({kernel_var_name} == nullptr) {{")
        self.writeline(
            DeferredCudaKernelLine(
                kernel_name,
                """    """
                + kernel_var_name
                + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);"""
                if V.graph.aot_mode
                else """    """
                + kernel_var_name
                + """ = loadKernel("%s", "%s", %s);""",
                keys,
            )
        )
        self.writeline("}")
        return kernel_var_name
272
+
273
    def generate_args_decl(self, call_args, arg_types):
        """Declare a C++ variable per kernel argument and return the
        comma-joined list of their addresses (as expected by cuLaunchKernel).

        Tensor args become CUdeviceptr (or an extracted scalar for
        ``.item()`` args); ints/floats become int/float; anything else falls
        back to ``auto``.
        """
        new_args = []
        for arg, arg_type in zip(call_args, arg_types):
            var_name = f"var_{next(self.arg_var_id)}"
            if isinstance(arg_type, torch_dtype):
                if arg.endswith(".item()"):
                    # Need to declare a scalar in this case
                    ctype = DTYPE_TO_CPP[arg_type]
                    # Strip the trailing ".item()" (7 chars).
                    arg = arg[:-7]
                    if config.abi_compatible:
                        self.codegen_tensor_item(
                            arg_type,
                            arg,
                            var_name,
                        )
                    else:
                        from torch import bfloat16, float16

                        if arg_type in (float16, bfloat16):
                            # Half-precision scalars are widened to float for
                            # the kernel launch ABI.
                            var_name_tmp = f"{var_name}_tmp"
                            self.writeline(
                                f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();"
                            )
                            self.writeline(f"float {var_name} = float({var_name_tmp});")
                        else:
                            self.writeline(
                                f"{ctype} {var_name} = {arg}.item<{ctype}>();"
                            )
                else:
                    if config.abi_compatible:
                        self.writeline(
                            maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};")
                        )
                        self.writeline(
                            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                        )
                    else:
                        self.writeline(
                            maybe_hipify_code_wrapper(
                                f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                            )
                        )
            elif arg_type in (sympy.Integer, int):
                self.writeline(f"int {var_name} = {self.expr_printer(arg)};")
            elif arg_type in (sympy.Float, float):
                self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
            else:
                self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
            new_args.append(f"&{var_name}")

        return ", ".join(new_args)
324
+
325
+ def generate_default_grid(
326
+ self,
327
+ kernel_name: str,
328
+ grid: List[Any],
329
+ cuda: bool = True,
330
+ grid_callable: Optional[Callable[..., Any]] = None,
331
+ **grid_extra_kwargs,
332
+ ):
333
+ """
334
+ Generate grid configs for launching a CUDA kernel using the grid
335
+ function from triton_heuristics. Because its computation needs
336
+ to read kernel config after autotune, it is done in a deferred way
337
+ using DeferredCudaDefaultGrid.
338
+ """
339
+ if not cuda:
340
+ return grid
341
+ return DeferredCudaDefaultGrid(
342
+ kernel_name, grid, grid_callable, **grid_extra_kwargs
343
+ )
344
+
345
+ def generate_kernel_call(
346
+ self,
347
+ kernel_name: str,
348
+ call_args,
349
+ grid=None,
350
+ device_index=None,
351
+ cuda=True,
352
+ triton=True,
353
+ arg_types=None,
354
+ raw_args=None,
355
+ grid_fn: str = "grid",
356
+ triton_meta=None,
357
+ autotune_configs=None,
358
+ grid_extra_kwargs="",
359
+ ):
360
+ assert arg_types is not None and len(call_args) == len(
361
+ arg_types
362
+ ), "call_args and arg_types do not match"
363
+
364
+ if not cuda:
365
+ # Even in CppWrapperCuda, we may see cpp kernels
366
+ return super().generate_kernel_call(
367
+ kernel_name,
368
+ call_args,
369
+ grid,
370
+ device_index,
371
+ cuda,
372
+ triton,
373
+ arg_types,
374
+ raw_args,
375
+ grid_fn,
376
+ triton_meta,
377
+ autotune_configs,
378
+ grid_extra_kwargs,
379
+ )
380
+
381
+ device_index, call_args = self.prepare_triton_kernel_call(
382
+ device_index, call_args
383
+ )
384
+ kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
385
+
386
+ # args with value 1 are added into equal_to_1 and constants
387
+ # in triton_meta (in the Python codegen) which makes them
388
+ # inlined in the PTX and compiled CUBIN
389
+ if (
390
+ triton_meta is not None
391
+ and "configs" in triton_meta
392
+ and triton_meta["configs"]
393
+ ):
394
+ equal_to_1 = triton_meta["configs"][0].equal_to_1
395
+ call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1]
396
+ arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
397
+
398
+ call_args_str = self.generate_args_decl(call_args, arg_types)
399
+ kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
400
+ self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
401
+ stream = (
402
+ "stream"
403
+ if V.graph.aot_mode
404
+ else self.write_get_raw_stream(device_index, V.graph)
405
+ )
406
+
407
+ grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
408
+ self.writeline(
409
+ DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
410
+ )
411
+
412
+ kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
413
+ # add debug printer code for all triton kernel related calls
414
+ debug_printer_manager = V.graph.wrapper_code.debug_printer
415
+ debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
416
+ with debug_printer_manager:
417
+ self.writeline(f"if ({grid_var}.is_non_zero()) {{")
418
+ self.writeline(
419
+ DeferredCudaKernelLine(
420
+ kernel_name,
421
+ r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
422
+ kernel_var_name,
423
+ f"{grid_var}.grid_x",
424
+ f"{grid_var}.grid_y",
425
+ f"{grid_var}.grid_z",
426
+ kernel_args_var,
427
+ stream,
428
+ ),
429
+ ("num_warps", "shared_mem"),
430
+ ),
431
+ )
432
+ self.writeline("}")
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc ADDED
Binary file (7.61 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc ADDED
Binary file (2.27 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc ADDED
Binary file (20.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc ADDED
Binary file (17.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc ADDED
Binary file (1.49 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc ADDED
Binary file (72.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from typing import cast, Sequence
4
+
5
+ from ...._dynamo.utils import counters
6
+ from ... import config
7
+ from ...codecache import code_hash, get_path
8
+ from ...ir import CUDATemplateBuffer
9
+ from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
10
+ from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
11
+ from ...virtualized import V
12
+ from ..common import IndentedBuffer
13
+
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+
18
+ class CUDACPPScheduling(BaseScheduling):
19
+ """
20
+ Partial Scheduling implementation for CUDA C++ Kernels.
21
+ This class is intended to be used in combination with TritonScheduling,
22
+ and delegated to by CUDACombinedScheduling.
23
+
24
+ It handles fusion decisions and CUDA C++ specific template code generation.
25
+ """
26
+
27
+ def __init__(self, scheduler: Scheduler) -> None:
28
+ super().__init__()
29
+ self.scheduler = scheduler
30
+
31
+ @classmethod
32
+ def get_backend_features(cls, device):
33
+ return {}
34
+
35
+ def group_fn(self, sizes):
36
+ return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
37
+
38
+ @staticmethod
39
+ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
40
+ return isinstance(node, SchedulerNode) and isinstance(
41
+ node.node, CUDATemplateBuffer
42
+ )
43
+
44
+ def can_fuse_vertical(
45
+ self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
46
+ ) -> bool:
47
+ return False
48
+
49
+ def define_kernel(self, src_code: str, node_schedule) -> str:
50
+ wrapper = V.graph.wrapper_code
51
+ if src_code in wrapper.src_to_kernel:
52
+ kernel_name = wrapper.src_to_kernel[src_code]
53
+ else:
54
+ fused_name = (
55
+ get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
56
+ if config.triton.descriptive_names
57
+ else ""
58
+ )
59
+ kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
60
+ # use the original src_code as the key
61
+ wrapper.src_to_kernel[src_code] = kernel_name
62
+ src_code = src_code.replace("KERNEL_NAME", kernel_name)
63
+
64
+ _, _, kernel_path = get_path(code_hash(src_code), "py")
65
+
66
+ compile_wrapper = IndentedBuffer()
67
+ compile_wrapper.writeline("async_compile.cuda(r'''")
68
+ compile_wrapper.splice(src_code, strip=True)
69
+ compile_wrapper.writeline("''', 'so')")
70
+
71
+ metadata_comment = f"# kernel path: {kernel_path}"
72
+ origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
73
+ metadata_comment += "\n" + origins + "\n" + detailed_origins
74
+ wrapper.define_kernel(
75
+ kernel_name, compile_wrapper.getvalue(), metadata_comment
76
+ )
77
+ return kernel_name
78
+
79
+ def codegen_template(
80
+ self,
81
+ template_node: BaseSchedulerNode,
82
+ epilogue_nodes: Sequence[BaseSchedulerNode],
83
+ ):
84
+ """
85
+ Codegen a CUDA template, possibly with fused epilogues
86
+ """
87
+ counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
88
+ assert self.is_cuda_cpp_template(
89
+ template_node
90
+ ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
91
+ template_node = cast(SchedulerNode, template_node)
92
+ _, (numel, rnumel) = template_node.group
93
+ assert rnumel == 1
94
+ ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
95
+ kernel, render = ctb.make_kernel_render(ctb)
96
+ with kernel:
97
+ template_node.mark_run()
98
+ src_code = render()
99
+
100
+ with V.set_kernel_handler(kernel):
101
+ node_schedule = [template_node]
102
+ kernel_name = self.define_kernel(src_code, node_schedule)
103
+
104
+ # debug printing values of intermediate tensors
105
+ _, call_args, arg_signatures, _ = kernel.args.python_argdefs()
106
+ debug_printer_manager = V.graph.wrapper_code.debug_printer
107
+ debug_printer_manager.set_printer_args(
108
+ call_args, kernel_name, arg_signatures, kernel
109
+ )
110
+ with debug_printer_manager:
111
+ kernel.call_kernel(kernel_name, ctb)
112
+
113
+ V.graph.removed_buffers |= kernel.removed_buffers
114
+ self.scheduler.free_buffers()
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from ... import config
8
+
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
+ def get_cuda_arch() -> Optional[str]:
14
+ try:
15
+ cuda_arch = config.cuda.arch
16
+ if cuda_arch is None:
17
+ # Get Compute Capability of the first Visible device
18
+ major, minor = torch.cuda.get_device_capability(0)
19
+ return str(major * 10 + minor)
20
+ return str(cuda_arch)
21
+ except Exception as e:
22
+ log.error("Error getting cuda arch: %s", e)
23
+ return None
24
+
25
+
26
+ def get_cuda_version() -> Optional[str]:
27
+ try:
28
+ cuda_version = config.cuda.version
29
+ if cuda_version is None:
30
+ cuda_version = torch.version.cuda
31
+ return cuda_version
32
+ except Exception as e:
33
+ log.error("Error getting cuda version: %s", e)
34
+ return None
35
+
36
+
37
+ @functools.lru_cache(None)
38
+ def nvcc_exist(nvcc_path: str = "nvcc") -> bool:
39
+ if nvcc_path is None:
40
+ return False
41
+ import subprocess
42
+
43
+ res = subprocess.call(
44
+ ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
45
+ )
46
+ return res == 0
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
4
+
5
+ from ...autotune_process import CUDABenchmarkRequest
6
+ from ...ir import (
7
+ Buffer,
8
+ ChoiceCaller,
9
+ CUDATemplateBuffer,
10
+ IRNode,
11
+ Layout,
12
+ PrimitiveInfoType,
13
+ TensorBox,
14
+ )
15
+ from ...utils import sympy_product
16
+ from ...virtualized import V
17
+ from ..common import IndentedBuffer, Kernel, OpOverrides
18
+ from ..cpp_utils import CppPrinter, DTYPE_TO_CPP
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+ cexpr = CppPrinter().doprint
27
+
28
+
29
+ def _normalize_idx(index: int, total_length: int) -> int:
30
+ return index if index >= 0 else index + total_length
31
+
32
+
33
+ class CUDAKernel(Kernel):
34
+ """
35
+ Baseclass for CUDA / Cutlass based Kernels
36
+ """
37
+
38
+ overrides = OpOverrides # type: ignore[assignment]
39
+
40
+
41
+ class CUDATemplateKernel(CUDAKernel):
42
+ """
43
+ Template kernels defined by CUDA / Cutlass in C++.
44
+ """
45
+
46
+ _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
47
+
48
+ def __init__(self, kernel_name) -> None:
49
+ """
50
+ Initializes a new instance of the CUDATemplateKernel class.
51
+
52
+ Args:
53
+ kernel_name (str): The name of the kernel.
54
+ """
55
+ super().__init__()
56
+ self.kernel_name = kernel_name
57
+ # Mapping from arg name to IRNode.
58
+ self.named_nodes: Dict[str, IRNode] = {}
59
+
60
+ def arg_name(self, node: IRNode) -> Optional[str]:
61
+ """
62
+ Returns arg name of a given input or output node.
63
+ """
64
+ if node is None:
65
+ return None
66
+ return {**self.args.input_buffers, **self.args.output_buffers}.get(
67
+ node.get_name(), None
68
+ )
69
+
70
+ def check_not_null(self, node: IRNode) -> str:
71
+ """
72
+ Generates code to check that a node is not null.
73
+ """
74
+
75
+ if node is None:
76
+ return ""
77
+
78
+ size_str = self.size(node, 0, -1)
79
+ name_str = self.arg_name(node)
80
+ if name_str is None:
81
+ return ""
82
+
83
+ res = IndentedBuffer(initial_indent=2)
84
+ res.tabwidth = 1
85
+ res.splice(
86
+ f"""
87
+ {{
88
+ if (!{name_str}) {{
89
+ int64_t {name_str}_size = {size_str};
90
+ if ({name_str}_size > 0) {{
91
+ throw std::runtime_error("input {name_str} is null but size is not 0!");
92
+ }}
93
+ }}
94
+ }}
95
+ """
96
+ )
97
+ return res.getvalue()
98
+
99
+ def def_kernel(
100
+ self,
101
+ inputs: List[IRNode],
102
+ outputs: List[IRNode],
103
+ names_str: str = "",
104
+ input_reorder: Optional[List[int]] = None,
105
+ ) -> str:
106
+ """
107
+ Hook called from template code to generate function definition and
108
+ needed args.
109
+
110
+ Args:
111
+ inputs: List of input IRNodes
112
+ outputs: List of output IRNodes
113
+ names_str: Comma separated list of input + output argument names.
114
+ input_reorder: The actual order of input nodes.
115
+ e.g. The template might have input argument defined as [X, W, Bias],
116
+ and the actual input passed into this template could be [Bias, X, W].
117
+ In this case, the `input_reorder` would be [2, 0, 1].
118
+ """
119
+
120
+ names = [x.strip() for x in names_str.strip().split(",")]
121
+ if len(inputs) + len(outputs) != len(names):
122
+ raise RuntimeError(
123
+ f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
124
+ )
125
+
126
+ if input_reorder is not None:
127
+ assert len(inputs) == len(input_reorder)
128
+ else:
129
+ input_reorder = list(range(len(inputs)))
130
+
131
+ for idx in input_reorder:
132
+ name = names[idx]
133
+ node = inputs[idx]
134
+ if node is not None:
135
+ self.named_nodes[name] = node
136
+ self.args.input_buffers[node.get_name()] = name
137
+
138
+ for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
139
+ if node is not None:
140
+ self.named_nodes[name] = node
141
+ self.args.output_buffers[node.get_name()] = name
142
+
143
+ arg_defs, *_ = self.args.cpp_argdefs()
144
+ return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"
145
+
146
+ def call_kernel(
147
+ self,
148
+ name: str,
149
+ node: "CUDATemplateBuffer", # type: ignore[name-defined]
150
+ ) -> None:
151
+ """
152
+ Generates code to call the kernel through V.graph.wrapper_code.
153
+ used from within torch._inductor.wrapper.WrapperCodeGen
154
+
155
+ name: Name of kernel function.
156
+ node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
157
+ as well as all required inputs and outputs.
158
+ """
159
+ wrapper = V.graph.wrapper_code
160
+ _, call_args, _, arg_types = self.args.python_argdefs()
161
+ # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
162
+ for i in range(len(call_args)):
163
+ if V.graph.is_unspec_arg(call_args[i]):
164
+ call_args[i] = call_args[i] + ".item()"
165
+ else:
166
+ call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"
167
+
168
+ # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
169
+ # workspace_size should have already been retrieved prior to this call.
170
+ call_args.append("None")
171
+
172
+ if node.get_workspace_size() > 0:
173
+ wrapper.generate_workspace_allocation(
174
+ node.get_workspace_size(), V.graph.scheduler.current_device, False
175
+ )
176
+ call_args.append("c_void_p(workspace.data_ptr())")
177
+ else:
178
+ call_args.append("None")
179
+
180
+ wrapper.generate_kernel_call(
181
+ name,
182
+ call_args,
183
+ cuda=True,
184
+ triton=False,
185
+ arg_types=arg_types,
186
+ )
187
+ if node.get_workspace_size() > 0:
188
+ wrapper.writeline(wrapper.make_free_by_names(["workspace"]))
189
+
190
+ def dtype(self, node: IRNode) -> Optional[str]:
191
+ """
192
+ Generates code which represents dtype of a given node.
193
+ """
194
+
195
+ if node is None:
196
+ return "void"
197
+ return DTYPE_TO_CPP.get(node.get_layout().dtype)
198
+
199
+ def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
200
+ # Helper method, called into from CUTLASSGemmTemplate
201
+ if node is None:
202
+ return default_dtype
203
+ from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate
204
+
205
+ return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
206
+
207
+ def max_valid_index(self, node: IRNode, default=-1):
208
+ # Helper method, called into from CUTLASSGemmTemplate
209
+ if node is None:
210
+ return default
211
+ max_valid_offset = 0
212
+ for i in range(len(node.get_size())):
213
+ max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
214
+ return max_valid_offset
215
+
216
+ def offset(self, node: IRNode) -> str:
217
+ """
218
+ Generates code which represents offset of a given node.
219
+ """
220
+
221
+ if node is None:
222
+ return "0"
223
+ return str(node.get_layout().offset)
224
+
225
+ def ptr(self, node: IRNode) -> str:
226
+ """
227
+ Generates code which represents pointer of a given node.
228
+ """
229
+
230
+ if node is None:
231
+ return "nullptr"
232
+ arg_name = self.arg_name(node)
233
+ if arg_name is None:
234
+ return "nullptr"
235
+ offset = self.offset(node)
236
+ return arg_name if offset == "0" else f"{arg_name} + {offset}"
237
+
238
+ def size(
239
+ self,
240
+ node: IRNode,
241
+ start_index: int,
242
+ end_index: Optional[int] = None,
243
+ default_value: int = 0,
244
+ ) -> str:
245
+ """
246
+ Hook called from template code to get the size of an arg.
247
+ Generates code which represents size of a given node in [start_index, end_index).
248
+ If node is None, returns default_value.
249
+
250
+ TODO: Will add needed args to pass it in if it is dynamic.
251
+ """
252
+
253
+ if node is None:
254
+ return str(default_value)
255
+
256
+ start_index = _normalize_idx(start_index, len(node.get_size()))
257
+ if end_index is None:
258
+ end_index = start_index
259
+ end_index = _normalize_idx(end_index, len(node.get_size()))
260
+
261
+ sizes = node.get_size()[start_index : end_index + 1]
262
+ if len(sizes) == 0:
263
+ return str(default_value)
264
+
265
+ val = sympy_product(sizes)
266
+ return cexpr(self.rename_indexing(val))
267
+
268
+ def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
269
+ """
270
+ Hook called from template code to get the stride of an arg.
271
+ Generates code which represents stride of a given node at index.
272
+ If node is None, returns default_value.
273
+
274
+ TODO: Will add needed args to pass it in if it is dynamic.
275
+ """
276
+
277
+ if node is None:
278
+ return str(default_value)
279
+
280
+ index = _normalize_idx(index, len(node.get_size()))
281
+ if index < 0:
282
+ return str(default_value)
283
+
284
+ stride = node.get_stride()[index]
285
+ return cexpr(self.rename_indexing(stride))
286
+
287
+ def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
288
+ """
289
+ Hook called from template code to get the row or column stride of an arg.
290
+ This is required by some CUTLASS 2.X APIs.
291
+ If the node is in row_major, it returns stride[-2].
292
+ If the node is in column_major, it returns stride[-1].
293
+
294
+ TODO: Will add needed args to pass it in if it is dynamic.
295
+ """
296
+
297
+ if node is None or len(node.get_stride()) < 2:
298
+ return str(default_value)
299
+
300
+ stride0 = node.get_stride()[-1]
301
+ stride1 = node.get_stride()[-2]
302
+ if stride0 == 1:
303
+ return cexpr(self.rename_indexing(stride1))
304
+ elif stride1 == 1:
305
+ return cexpr(self.rename_indexing(stride0))
306
+ else:
307
+ raise RuntimeError(
308
+ f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
309
+ )
310
+
311
+
312
+ class CUDATemplateCaller(ChoiceCaller):
313
+ """
314
+ CUDATemplateCaller
315
+
316
+ This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
317
+ Attributes:
318
+ name (str): The name of the caller.
319
+ category (str): The category of the caller.
320
+ bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
321
+ template_buffer (CUDATemplateBuffer): The template buffer for the caller.
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ name: str,
327
+ category: str,
328
+ input_nodes: List[Buffer],
329
+ layout: Layout,
330
+ make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str],
331
+ bmreq: CUDABenchmarkRequest,
332
+ template: "CUDATemplate", # type: ignore[name-defined]
333
+ info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg]
334
+ ) -> None:
335
+ super().__init__(name, input_nodes, layout)
336
+ self.category = category
337
+ self.make_kernel_render = make_kernel_render
338
+ self.bmreq = bmreq
339
+ self.template = template
340
+ self.info_kwargs = info_kwargs
341
+
342
+ def precompile(self) -> None:
343
+ assert self.bmreq is not None
344
+ self.bmreq.precompile()
345
+
346
+ def benchmark(self, *args, out) -> float:
347
+ assert self.bmreq is not None
348
+ return self.bmreq.benchmark(
349
+ *args, output_tensor=out
350
+ ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred
351
+
352
+ def __str__(self) -> str:
353
+ return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
354
+
355
+ def call_name(self) -> str:
356
+ return f"cuda_template_kernels.{self.name}"
357
+
358
+ def hash_key(self) -> str:
359
+ return "-".join(
360
+ [
361
+ self.category,
362
+ self.bmreq.hash_key,
363
+ ]
364
+ )
365
+
366
+ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
367
+ """Information returned here is logged to the autotune log file when that is enabled."""
368
+ if self.info_kwargs is not None and "op" in self.info_kwargs:
369
+ op: Any = self.info_kwargs["op"]
370
+ return {
371
+ "backend": "CUDA",
372
+ "op_type": type(op).__name__,
373
+ "op_conf_name": str(op.configuration_name()),
374
+ "op_arch": str(op.arch),
375
+ "tile_shape": str(op.tile_description.tile_shape),
376
+ "epilogue_schedule": str(op.epilogue_schedule),
377
+ "kernel_schedule": str(op.kernel_schedule),
378
+ "element_accumulator": str(op.accumulator_type()),
379
+ "op_name": str(op.procedural_name()),
380
+ "instruction_shape": str(
381
+ op.tile_description.math_instruction.instruction_shape
382
+ ),
383
+ }
384
+ else:
385
+ return {"backend": "CUDA", "op_type": "unknown"}
386
+
387
+ def output_node(self) -> TensorBox:
388
+ self.bmreq.update_workspace_size()
389
+ return TensorBox.create(
390
+ CUDATemplateBuffer(
391
+ layout=self.layout,
392
+ inputs=self.input_nodes,
393
+ make_kernel_render=self.make_kernel_render,
394
+ workspace_size=self.bmreq.workspace_size,
395
+ template=self.template,
396
+ )
397
+ )
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ from typing import List, Optional
6
+ from unittest.mock import patch
7
+
8
+ import sympy
9
+
10
+ import torch
11
+
12
+ from ...autotune_process import CUDABenchmarkRequest, TensorMeta
13
+ from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
14
+ from ...utils import IndentedBuffer, unique
15
+ from ...virtualized import V
16
+ from ..common import KernelTemplate
17
+ from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
18
+
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+
23
+ class CUDATemplate(KernelTemplate):
24
+ index_counter = itertools.count()
25
+
26
+ def __init__(
27
+ self,
28
+ name: str,
29
+ input_nodes: List[Buffer],
30
+ layout: Layout,
31
+ input_reorder: Optional[List[int]] = None,
32
+ ) -> None:
33
+ """
34
+
35
+ Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.
36
+
37
+ Args:
38
+ name (str): The name of the CUDATemplate object.
39
+ input_nodes (List[IRNode]): A list of input IRNodes.
40
+ layout (Layout): The layout of the output buffer / tensor.
41
+ input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.
42
+
43
+ """
44
+ super().__init__(name)
45
+ self.input_nodes = input_nodes
46
+ self.output_node: Buffer = Buffer("buf_out", layout)
47
+ self.input_reorder = input_reorder
48
+ self.layout = layout
49
+
50
+ def generate( # type: ignore[override]
51
+ self,
52
+ **kwargs,
53
+ ) -> CUDATemplateCaller:
54
+ """
55
+ Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
56
+ may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.
57
+
58
+ Args:
59
+ kwargs: Additional keyword arguments.
60
+
61
+ Returns:
62
+ A CUDATemplateCaller object representing the generated CUDA template caller.
63
+ """
64
+ kernel_name = f"cuda_{self.name}"
65
+ with patch.object(
66
+ V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
67
+ ), CUDATemplateKernel(
68
+ kernel_name=kernel_name,
69
+ ) as kernel:
70
+ code = self.render(kernel=kernel, **kwargs)
71
+ _, call_args, _, _ = kernel.args.python_argdefs()
72
+ log.debug("Generated Code:\n%s", code)
73
+ log.debug(
74
+ "Args: cpp_argdefs: %s, python_argdefs: %s",
75
+ kernel.args.cpp_argdefs(),
76
+ kernel.args.python_argdefs(),
77
+ )
78
+
79
+ input_reorder = (
80
+ self.input_reorder
81
+ if self.input_reorder is not None
82
+ else list(range(len(self.input_nodes)))
83
+ )
84
+ expected_args = list(
85
+ unique(self.input_nodes[idx].get_name() for idx in input_reorder)
86
+ )
87
+ expected_args.extend([self.output_node.get_name()])
88
+ assert list(call_args)[: len(expected_args)] == expected_args, (
89
+ call_args,
90
+ expected_args,
91
+ )
92
+ extra_args = V.graph.sizevars.size_hints(
93
+ map(sympy.expand, call_args[len(expected_args) :])
94
+ )
95
+
96
+ kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"
97
+
98
+ # create the BenchmarkRequest
99
+ bmreq = CUDABenchmarkRequest(
100
+ kernel_name=kernel_name,
101
+ input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
102
+ output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
103
+ extra_args=extra_args,
104
+ source_code=code,
105
+ )
106
+
107
+ def make_kernel_render(
108
+ template_node: CUDATemplateBuffer,
109
+ epilogue_nodes: Optional[List[IRNode]] = None,
110
+ ):
111
+ kernel = CUDATemplateKernel(
112
+ kernel_name="KERNEL_NAME",
113
+ )
114
+ render = functools.partial(
115
+ self.render,
116
+ kernel=kernel,
117
+ template_buffer_node=template_node,
118
+ epilogue_nodes=epilogue_nodes,
119
+ **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate
120
+ )
121
+ return kernel, render
122
+
123
+ return CUDATemplateCaller(
124
+ kernel_hash_name,
125
+ self.name,
126
+ self.input_nodes,
127
+ self.output_node.get_layout(),
128
+ make_kernel_render,
129
+ bmreq,
130
+ self,
131
+ kwargs,
132
+ )
133
+
134
+ def header(self) -> IndentedBuffer:
135
+ res = IndentedBuffer()
136
+ res.splice(
137
+ """
138
+ #include <exception>
139
+ #include <iostream>
140
+ #include <memory>
141
+ #include <random>
142
+ #include <vector>
143
+ """
144
+ )
145
+ return res
146
+
147
+ def globals(self) -> IndentedBuffer:
148
+ res = IndentedBuffer()
149
+ res.splice(
150
+ """
151
+ // We compile all models with -fvisibility=hidden. Any symbols that need to be
152
+ // exposed in the final shared library must be declared with PT_EXPORT to make
153
+ // them visible.
154
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
155
+ #define PT_EXPORT __attribute__((__visibility__("default")))
156
+ #else
157
+ #ifdef _WIN32
158
+ #define PT_EXPORT __declspec(dllexport)
159
+ #else
160
+ #define PT_EXPORT
161
+ #endif
162
+ #endif
163
+ using bfloat16 = nv_bfloat16;
164
+ """
165
+ )
166
+ return res
167
+
168
+ def render(self, **kwargs) -> str:
169
+ raise NotImplementedError
170
+
171
+
172
+ class CUTLASSTemplate(CUDATemplate):
173
+ """
174
+ CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
175
+ CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
176
+ """
177
+
178
+ def header(self) -> IndentedBuffer:
179
+ res = super().header()
180
+ res.splice(
181
+ """
182
+ #include "cute/tensor.hpp"
183
+ #include "cutlass/cutlass.h"
184
+ #include "cutlass/numeric_types.h"
185
+ #include "cutlass/tensor_ref.h"
186
+ #include "cutlass/util/host_tensor.h"
187
+ #include "cutlass/util/reference/host/tensor_fill.h"
188
+ #include "cutlass/util/reference/device/tensor_fill.h"
189
+ #include "cutlass/util/device_memory.h"
190
+ """
191
+ )
192
+ return res
193
+
194
+ def globals(self) -> IndentedBuffer:
195
+ res = super().globals()
196
+ res.splice(
197
+ """
198
+ using namespace cute;
199
+ #define CUTLASS_CHECK(status) \\
200
+ { \\
201
+ cutlass::Status error = status; \\
202
+ if (error != cutlass::Status::kSuccess) { \\
203
+ auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\
204
+ cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\
205
+ throw std::runtime_error(msg); \\
206
+ } \\
207
+ }
208
+
209
+ // Used as pass-through functor in EVT just for type casting / rounding
210
+ template <typename T>
211
+ struct identity_op {
212
+ CUTLASS_HOST_DEVICE
213
+ T operator()(T val) const { return val; }
214
+ };
215
+
216
+ """
217
+ )
218
+ return res
219
+
220
+ def cute_int(self, int_str: str, var_name: str) -> str:
221
+ res = ""
222
+ if int_str in {"1", "1L"}:
223
+ res = "cute::Int<1>{}"
224
+ else:
225
+ res = int_str
226
+
227
+ return f"{res} /* {var_name} */"
228
+
229
+ _DTYPE_TO_CUTLASS = {
230
+ torch.float32: "float",
231
+ torch.float64: "double",
232
+ torch.float16: "cutlass::half_t",
233
+ torch.int32: "int32_t",
234
+ torch.int16: "int16_t",
235
+ torch.int8: "int8_t",
236
+ torch.uint8: "uint8_t",
237
+ torch.bool: "bool",
238
+ torch.bfloat16: "cutlass::bfloat16_t",
239
+ }
240
+
241
+ _DTYPE_TO_CUTLASS_SPARSE_META = {
242
+ torch.int32: "uint32_t",
243
+ torch.int16: "uint16_t",
244
+ }
245
+
246
+ def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
247
+ if node is None:
248
+ return ptr
249
+ else:
250
+ return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
251
+
252
+ def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str:
253
+ if node is None:
254
+ return ptr
255
+ else:
256
+ return (
257
+ f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
258
+ )
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Dict, List
3
+ from unittest.mock import patch
4
+
5
+ import sympy
6
+
7
+ import torch._inductor.virtualized as virtualized
8
+ from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise
9
+ from torch._inductor.utils import IndentedBuffer, sympy_str
10
+
11
+
12
+ # Used as a magic string to indicate an unsupported sympy expression
13
+ # became part of generated C++ code.
14
+ _MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"
15
+
16
+
17
+ def _arg_str(a):
18
+ if isinstance(a, sympy.Expr):
19
+ # If this return value containing the _MAGIC_SYMPY_ERROR_STRING
20
+ # is used as part of the final generated C++ code,
21
+ # a CUTLASSEVTOpNotImplementedError is raised to indicate that
22
+ # the op could not be converted to a valid EVT expression.
23
+ return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')"
24
+ return str(a)
25
+
26
+
27
class CUTLASSEVTOpNotImplementedError(NotImplementedError):
    """Raised when a pointwise op encountered during epilogue fusion cannot be
    translated into a valid CUTLASS EVT expression (unsupported op, operand,
    dtype, or sympy/indexing expression)."""

    pass
29
+
30
+
31
class CutlassEVTEpilogueTypeFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) functor declarations.

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name, evt_type_name):
        """
        Initialize an instance of CutlassEVTEpilogueTypeFormatter.

        Parameters:
            accumulator_node_name (str): The name of the output Buffer for the GEMM operation
                in the original (unfused) IR graph.
            evt_type_name (str): The output name of the EVT type we are generating.
        """
        self.accumulator_node_name = accumulator_node_name
        self.output = IndentedBuffer(0)  # accumulates the generated "using ..." declarations
        self.var_counter = 0  # suffix counter for the generated EVT_expr_<n> names
        self.evt_type_name = evt_type_name
        self.aliases = {}  # buffer name -> EVT expression of already-visited epilogue nodes

    @staticmethod
    def ir_to_evt_string(
        template_output_node_name: str,
        evt_type_name: str,
        epilogue_nodes: List[IRNode],
    ):
        """
        Formats IR nodes into a string representation compatible with Cutlass EVT format.

        Args:
            template_output_node_name (str): The name of the template output node.
            evt_type_name (str): The name of the EVT type.
            epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now,
                these must be ComputedBuffer nodes wrapping Pointwise nodes.

        Returns:
            A string representation of the IR nodes formatted according to the Cutlass EVT format.

        Raises:
            RuntimeError: If an epilogue node is not a ComputedBuffer wrapping a Pointwise node.
            CUTLASSEVTOpNotImplementedError: If the epilogue uses an op or a sympy / indexing
                expression that cannot be expressed as a CUTLASS EVT.
        """
        formatter = CutlassEVTEpilogueTypeFormatter(
            template_output_node_name, evt_type_name
        )

        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                if isinstance(node, ComputedBuffer):
                    pnode = node.data
                else:
                    raise RuntimeError(
                        "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
                    )
                assert isinstance(pnode, Pointwise)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # Each epilogue node results in a single "using" statement and may refer
                # to the previous steps by name. Guard against anonymous buffers, for
                # consistency with CutlassEVTEpilogueArgumentFormatter.
                if node.name is not None:
                    formatter.aliases[node.name] = result
            res = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            else:
                return res

    def __getattr__(self, name):
        """
        Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.

        Returns a wrapper that stringifies all arguments, dispatches to the matching
        _op_<name> method, and records the result as a named "using" declaration so
        subsequent ops can refer to it by variable name.
        """

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            fn = getattr(self, f"_op_{name}")
            line = fn(*fargs, **fkwargs)
            self.var_counter += 1
            varname = f"EVT_expr_{self.var_counter}"
            # replace line with a new variable name
            self.output.writeline(f"using {varname} = {line};")
            return varname

        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        if hasattr(self, f"_op_{name}"):
            return inner
        else:
            raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # Load an input to an operation. Might be the output of the matmul, the result
        # of a previous epilogue node, a constant or (TODO) an auxiliary input.
        if name == self.accumulator_node_name:
            return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Load a constant
        if str(dtype) in ("torch.float16", "torch.float32"):
            return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Perform a named operation on two inputs
        # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>"  # noqa: B950

    def _convert_to_output_dtype(self, a):
        # Convert the final output to the dtype of the output buffer
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>"  # noqa: B950

    def _op_to_dtype(self, a, *args, **kwargs):
        # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator
        # dtype.
        # It is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible
        # throughout the fusion chain.
        return a  # noqa: B950

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0), expressed via a broadcast zero constant.
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed as CUTLASS EVT nodes.
        raise CUTLASSEVTOpNotImplementedError

    # Add more ops here...
    def getvalue(self, result) -> str:
        """Finalize codegen: cast the last expression to the output dtype, bind it to
        the requested EVT type name, and return the accumulated declarations."""
        dtype_converted_expr = self._convert_to_output_dtype(
            f"EVT_expr_{self.var_counter}"
        )
        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
        return self.output.getvalue()
211
+
212
+
213
class CutlassEVTEpilogueArgumentFormatter:
    """
    Codegen class providing the entry point for generating Cutlass
    "Epilogue Visitor Tree" (EVT) argument initializer expressions.

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * Not meant to be instantiated by users; call
          CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...),
          which installs an instance as the ops handler behind
          virtualized.V.ops.[op-name].
        * Add further _op_<name> methods to support more pointwise operations.
    """

    def __init__(self, accumulator_node_name: str):
        """
        Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly.
        Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method.

        Args:
            accumulator_node_name (str): The name of the accumulator node which should contain
                the Matmul result before fusion according to the IR graph.
        """
        self.accumulator_node_name: str = accumulator_node_name
        # Output buffer for codegen.
        self.output: IndentedBuffer = IndentedBuffer(0)
        # Used to generate variable names; incremented for each new variable.
        self.var_counter: int = 0
        # Aliases for subexpression functors.
        self.aliases: Dict[str, str] = {}

    @staticmethod
    def ir_to_evt_argument_string(
        template_output_node_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """Render the EVT argument initializer string for the given epilogue nodes."""
        handler = CutlassEVTEpilogueArgumentFormatter(template_output_node_name)

        with virtualized.V.set_ops_handler(handler), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                assert isinstance(node, ComputedBuffer)
                pointwise = node.data
                assert isinstance(pointwise, Pointwise)
                idx = pointwise._index(pointwise.ranges)
                result = pointwise.inner_fn(idx)
                # Each epilogue node yields a single subexpression; later nodes
                # may refer back to it by buffer name.
                if node.name is not None:
                    handler.aliases[node.name] = result

            res: str = handler.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            return res

    def __getattr__(self, name):
        """Resolve V.ops.<name> calls to the matching _op_<name> handler."""
        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        if not hasattr(self, f"_op_{name}"):
            raise CUTLASSEVTOpNotImplementedError(name)

        def inner(*args, **kwargs):
            str_args = [_arg_str(a) for a in args]
            str_kwargs = {key: _arg_str(val) for key, val in kwargs.items()}
            handler_fn = getattr(self, f"_op_{name}")
            return handler_fn(*str_args, **str_kwargs)

        return inner

    def _op_load(self, name, index_expr):
        # The accumulator (matmul output) needs no explicit argument.
        if name == self.accumulator_node_name:
            return "{}"
        if name in self.aliases:
            return self.aliases[name]
        raise CUTLASSEVTOpNotImplementedError(
            f"Operand {name} not found. Auxiliary inputs not supported yet."
        )

    def _op_constant(self, value, dtype):
        # Scalar broadcast initializer; only fp16/fp32 constants are supported.
        if str(dtype) not in ("torch.float16", "torch.float32"):
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )
        return "{ static_cast<ElementAcc>(" + str(value) + ") }"

    def _cutlass_binary_functional_op(self, op, a, b):
        # Brace-initializer pairing the two operand subexpressions; the op name
        # is emitted only as a comment for readability of the generated source.
        return f"{{ /*{op}: */ {a}, {b} }}"

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0): pair the operand with a broadcast zero constant.
        zero = self._op_constant(0.0, "torch.float32")
        return "{" + str(a) + ", " + zero + "}"

    def _op_to_dtype(self, a, dtype, src_dtype=None):
        # It is asserted ( and ascertained during can_fuse decision ) that the dtype remains
        # compatible throughout the fusion chain, so the cast is a no-op here.
        assert dtype in (
            "torch.float32",
            "torch.float16",
        ), f"Unsupported dtype: {dtype}"
        assert src_dtype in (
            None,
            "torch.float32",
            "torch.float16",
        ), f"Unsupported source dtype: {src_dtype}"
        return a

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed as CUTLASS EVT arguments.
        raise CUTLASSEVTOpNotImplementedError

    def getvalue(self, result) -> str:
        """Wrap the final initializer expression in an outer brace pair."""
        return "{" + str(result) + "}"
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes). View file