koichi12 commited on
Commit
70fbf20
·
verified ·
1 Parent(s): a2ec7d8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc +3 -0
  5. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc +3 -0
  6. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc +3 -0
  7. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc +3 -0
  8. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +354 -0
  9. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py +19 -0
  13. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py +746 -0
  16. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py +276 -0
  17. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py +599 -0
  18. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +153 -0
  19. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py +80 -0
  20. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +406 -0
  21. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py +227 -0
  22. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py +909 -0
  23. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py +1317 -0
  24. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py +694 -0
  25. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py +854 -0
  26. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py +131 -0
  27. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py +1266 -0
  28. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py +212 -0
  29. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py +881 -0
  30. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py +1318 -0
  31. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py +800 -0
  32. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py +2589 -0
  33. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py +688 -0
  34. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py +145 -0
  35. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
  36. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -137,3 +137,8 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
137
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
138
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
139
  .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
137
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
138
  .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
139
  .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
140
+ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
141
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
142
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
143
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
144
+ .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:092aa1e8b674926d96609f7d70e837e88bb2433dce56bfb3b265696082850bf7
3
+ size 361762
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:088ceca24b4ba43a80ac34a889d96ba95ca4779739aea2e41a34600e2f7fd8ae
3
+ size 103679
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9dde4c92669d913e3e0a7bf7bbc82b533f3e5492ce09ea3322607c4c9cec549
3
+ size 106873
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:454c0716087aee149fe6ab1aaf1a2f50703e82c70c9edf9bd21bb618627510ad
3
+ size 176338
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d043ca807b2f2cb8a6fa93a9c0fe15f2091627b58bf0206323eb3d52b7c26cc
3
+ size 122305
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
2
+ #include <torch/csrc/inductor/aoti_runtime/interface.h>
3
+ #include <torch/csrc/inductor/aoti_runtime/model_container.h>
4
+ #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
5
+ #include <torch/csrc/inductor/aoti_runtime/thread_local.h>
6
+
7
+ #include <iostream>
8
+ #include <sstream>
9
+ #include <stdexcept>
10
+ #include <vector>
11
+
12
+ #define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
13
+ try { \
14
+ __VA_ARGS__ \
15
+ } catch (const std::exception& e) { \
16
+ std::cerr << "Error: " << e.what() << std::endl; \
17
+ return AOTI_RUNTIME_FAILURE; \
18
+ } catch (...) { \
19
+ std::cerr << "Unknown exception occurred." << std::endl; \
20
+ return AOTI_RUNTIME_FAILURE; \
21
+ } \
22
+ return AOTI_RUNTIME_SUCCESS;
23
+
24
+ #define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
25
+ do { \
26
+ AOTI_RUNTIME_CHECK( \
27
+ actual_size == expected_size, \
28
+ "expected " + std::string(name) + " vector size to be " + \
29
+ std::to_string(expected_size) + ", but got " + \
30
+ std::to_string(actual_size)); \
31
+ } while (0)
32
+
33
+ // AOTInductor uses at::addmm_out, which doesn't supports
34
+ // arguments that requires gradient. For this reason, we
35
+ // enforce no_grad context for run APIs.
36
+ //
37
+ // A RAII, thread local (!) guard that enables or disables grad mode upon
38
+ // construction, and sets it back to the original value upon destruction.
39
+ struct AOTINoGradGuard {
40
+ AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
41
+ aoti_torch_grad_mode_set_enabled(false);
42
+ }
43
+ ~AOTINoGradGuard() {
44
+ aoti_torch_grad_mode_set_enabled(prev_mode);
45
+ }
46
+ bool prev_mode;
47
+ };
48
+
49
+ extern "C" {
50
+
51
+ AOTIRuntimeError AOTInductorModelContainerCreate(
52
+ AOTInductorModelContainerHandle* container_handle,
53
+ size_t num_models,
54
+ bool is_cpu,
55
+ const char* cubin_dir) {
56
+ return AOTInductorModelContainerCreateWithDevice(
57
+ container_handle,
58
+ num_models,
59
+ is_cpu ? "cpu" : "cuda",
60
+ cubin_dir);
61
+ }
62
+
63
+ AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
64
+ AOTInductorModelContainerHandle* container_handle,
65
+ size_t num_models,
66
+ const char* device_str,
67
+ const char* cubin_dir) {
68
+ if (num_models == 0) {
69
+ std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
70
+ return AOTI_RUNTIME_FAILURE;
71
+ }
72
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
73
+ std::optional<std::string> cubin_dir_opt;
74
+ if (cubin_dir != nullptr) {
75
+ cubin_dir_opt.emplace(cubin_dir);
76
+ }
77
+ auto* container = new torch::aot_inductor::AOTInductorModelContainer(
78
+ num_models, std::string(device_str), cubin_dir_opt);
79
+ *container_handle =
80
+ reinterpret_cast<AOTInductorModelContainerHandle>(container);
81
+ })
82
+ }
83
+
84
+ AOTIRuntimeError AOTInductorModelContainerDelete(
85
+ AOTInductorModelContainerHandle container_handle) {
86
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
87
+ auto* container =
88
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
89
+ container_handle);
90
+ delete container;
91
+ });
92
+ }
93
+
94
+ AOTIRuntimeError AOTInductorModelContainerRun(
95
+ AOTInductorModelContainerHandle container_handle,
96
+ AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
97
+ // are stolen; the array itself is borrowed
98
+ size_t num_inputs,
99
+ AtenTensorHandle*
100
+ output_handles, // array for writing output AtenTensorHandle; handles
101
+ // will be stolen by the caller; the array itself is
102
+ // borrowed
103
+ size_t num_outputs,
104
+ AOTInductorStreamHandle stream_handle,
105
+ AOTIProxyExecutorHandle proxy_executor_handle) {
106
+ auto* container =
107
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
108
+ container_handle);
109
+ AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
110
+ AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
111
+
112
+ auto stream =
113
+ reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
114
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
115
+ AOTINoGradGuard guard;
116
+ container->run(
117
+ input_handles, output_handles, stream, proxy_executor_handle);
118
+ })
119
+ }
120
+
121
+ AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
122
+ AOTInductorModelContainerHandle container_handle,
123
+ size_t* num_constants) {
124
+ auto* container =
125
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
126
+ container_handle);
127
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
128
+ { *num_constants = container->num_constants(); })
129
+ }
130
+
131
+ AOTIRuntimeError AOTInductorModelContainerGetConstantName(
132
+ AOTInductorModelContainerHandle container_handle,
133
+ size_t idx,
134
+ const char** name) {
135
+ auto* container =
136
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
137
+ container_handle);
138
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
139
+ { *name = container->constant_name(idx); })
140
+ }
141
+
142
+ AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
143
+ AOTInductorModelContainerHandle container_handle,
144
+ size_t idx,
145
+ const char** original_fqn) {
146
+ auto* container =
147
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
148
+ container_handle);
149
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
150
+ { *original_fqn = container->constant_original_fqn(idx); })
151
+ }
152
+
153
+ AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
154
+ AOTInductorModelContainerHandle container_handle,
155
+ size_t idx,
156
+ bool* from_folded) {
157
+ auto* container =
158
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
159
+ CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
160
+ }
161
+
162
+ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
163
+ AOTInductorModelContainerHandle container_handle,
164
+ size_t idx,
165
+ int32_t* dtype) {
166
+ auto* container =
167
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
168
+ container_handle);
169
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
170
+ { *dtype = container->constant_dtype(idx); })
171
+ }
172
+
173
+ AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
174
+ AOTInductorModelContainerHandle container_handle,
175
+ AOTInductorConstantMapHandle constant_map_handle,
176
+ bool use_inactive,
177
+ bool validate_full_update) {
178
+ auto* container =
179
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
180
+ container_handle);
181
+ auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
182
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
183
+ container->update_constant_buffer(
184
+ *input_map, use_inactive, validate_full_update);
185
+ })
186
+ }
187
+
188
+ AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
189
+ AOTInductorModelContainerHandle container_handle,
190
+ AOTInductorConstantMapHandle constant_map_handle) {
191
+ return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
192
+ constant_map_handle,
193
+ /*use_inactive*/ true,
194
+ /*validate_full_update*/ true);
195
+ }
196
+
197
+ AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
198
+ AOTInductorModelContainerHandle container_handle,
199
+ bool use_inactive,
200
+ AOTInductorStreamHandle stream_handle,
201
+ AOTIProxyExecutorHandle proxy_executor_handle) {
202
+ auto* container =
203
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
204
+ container_handle);
205
+ auto stream =
206
+ reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
207
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
208
+ AOTINoGradGuard guard;
209
+ container->run_const_fold(use_inactive, stream, proxy_executor_handle);
210
+ })
211
+ }
212
+
213
+ AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
214
+ AOTInductorModelContainerHandle container_handle) {
215
+ auto* container =
216
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
217
+ container_handle);
218
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
219
+ container->swap_constant_buffer();
220
+ })
221
+ }
222
+
223
+ AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
224
+ AOTInductorModelContainerHandle container_handle,
225
+ size_t* ret_num_inputs) {
226
+ auto* container =
227
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
228
+ container_handle);
229
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
230
+ { *ret_num_inputs = container->num_inputs(); })
231
+ }
232
+
233
+ AOTIRuntimeError AOTInductorModelContainerGetInputName(
234
+ AOTInductorModelContainerHandle container_handle,
235
+ size_t input_idx,
236
+ const char** ret_input_names) {
237
+ auto* container =
238
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
239
+ container_handle);
240
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
241
+ { *ret_input_names = container->input_name(input_idx); })
242
+ }
243
+
244
+ AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
245
+ AOTInductorModelContainerHandle container_handle,
246
+ size_t* ret_num_outputs) {
247
+ auto* container =
248
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
249
+ container_handle);
250
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
251
+ { *ret_num_outputs = container->num_outputs(); })
252
+ }
253
+
254
+ AOTIRuntimeError AOTInductorModelContainerGetOutputName(
255
+ AOTInductorModelContainerHandle container_handle,
256
+ size_t output_idx,
257
+ const char** ret_output_names) {
258
+ auto* container =
259
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
260
+ container_handle);
261
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
262
+ { *ret_output_names = container->output_name(output_idx); })
263
+ }
264
+
265
+ AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
266
+ AOTInductorModelContainerHandle container_handle,
267
+ const char** in_spec,
268
+ const char** out_spec) {
269
+ auto* container =
270
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
271
+ container_handle);
272
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
273
+ *in_spec = container->get_in_spec();
274
+ *out_spec = container->get_out_spec();
275
+ })
276
+ }
277
+
278
+ AOTIRuntimeError AOTInductorModelCreate(
279
+ AOTInductorModelHandle* model_handle,
280
+ AOTInductorConstantMapHandle constant_map_handle){
281
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
282
+ auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
283
+ auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
284
+ auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
285
+
286
+ auto model = new torch::aot_inductor::AOTInductorModel(
287
+ constant_map,
288
+ constant_array,
289
+ "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
290
+ ""
291
+ );
292
+
293
+ if (input_map) {
294
+ for (auto const& kv : *input_map) {
295
+ constant_map->emplace(kv.first, kv.second);
296
+ }
297
+ } else {
298
+ model->load_constants();
299
+ }
300
+
301
+ *model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
302
+ })}
303
+
304
+ AOTIRuntimeError AOTInductorModelRun(
305
+ AOTInductorModelHandle model_handle,
306
+ AtenTensorHandle* input_handles,
307
+ AtenTensorHandle* output_handles) {
308
+ auto model =
309
+ reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
310
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
311
+ AOTINoGradGuard guard;
312
+ model->run_impl(
313
+ input_handles,
314
+ output_handles,
315
+ (torch::aot_inductor::DeviceStreamType) nullptr,
316
+ nullptr);
317
+ })
318
+ }
319
+
320
+ AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
321
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
322
+ auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
323
+ model_handle);
324
+ delete model;
325
+ })}
326
+
327
+ AOTIRuntimeError AOTInductorModelGetNumOutputs(
328
+ AOTInductorModelHandle model_handle,
329
+ size_t* ret_num_outputs) {
330
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
331
+ auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
332
+ *ret_num_outputs = model->num_outputs();
333
+ })
334
+ }
335
+
336
+ AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
337
+ AOTInductorModelHandle model_handle,
338
+ AOTInductorConstantMapHandle constant_map_handle) {
339
+ auto model =
340
+ reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
341
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
342
+ auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
343
+ auto input_map =
344
+ reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
345
+ constant_map_handle);
346
+
347
+ for (auto const& kv : *input_map) {
348
+ constant_map->emplace(kv.first, kv.second);
349
+ }
350
+ model->update_constants_map(std::move(constant_map));
351
+ })
352
+ }
353
+
354
+ } // extern "C"
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# mypy: allow-untyped-defs
from ..common import DeviceOpOverrides, register_device_op_overrides


class XPUDeviceOpOverrides(DeviceOpOverrides):
    """Code-generation snippets for Intel XPU device operations.

    Each method returns a string of Python source that the caller splices
    into generated wrapper code.
    """

    def import_get_raw_stream_as(self, name):
        # Import statement binding the raw-stream getter under ``name``.
        stmt = f"from torch._C import _xpu_getCurrentRawStream as {name}"
        return stmt

    def set_device(self, device_idx):
        # Call that selects the active XPU device.
        return f"torch.xpu.set_device({device_idx})"

    def synchronize(self):
        # Call that blocks until all queued XPU work completes.
        return "torch.xpu.synchronize()"

    def device_guard(self, device_idx):
        # Context-manager expression scoping work to ``device_idx``.
        return f"torch.xpu._DeviceGuard({device_idx})"


# Register the overrides under the "xpu" device key.
register_device_op_overrides("xpu", XPUDeviceOpOverrides())
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc ADDED
Binary file (30.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ from collections import deque
4
+ from typing import Dict, List, Set, Tuple
5
+
6
+ import torch
7
+ from torch.utils._pytree import tree_map
8
+
9
+ from ..._dynamo.utils import counters
10
+ from ..ir import (
11
+ ComputedBuffer,
12
+ FixedLayout,
13
+ FlexibleLayout,
14
+ InputBuffer,
15
+ StorageBox,
16
+ Subgraph,
17
+ TensorBox,
18
+ )
19
+ from ..lowering import lowerings
20
+ from ..pattern_matcher import (
21
+ Arg,
22
+ CallFunction,
23
+ Match,
24
+ PatternMatcherPass,
25
+ register_graph_pattern,
26
+ )
27
+ from ..select_algorithm import (
28
+ autotune_select_algorithm,
29
+ ExternKernelChoice,
30
+ TritonTemplate,
31
+ TritonTemplateCaller,
32
+ )
33
+ from ..utils import ceildiv
34
+
35
+
36
+ B2B_GEMM_PASS = PatternMatcherPass(
37
+ pass_name="b2b_gemm_pass",
38
+ )
39
+
40
+
41
def b2b_gemm_grid(M, P, meta):
    """Launch grid for the b2b-gemm templates: one program per (M, P) tile."""
    m_blocks = ceildiv(M, meta["BLOCK_SIZE_M"])
    p_blocks = ceildiv(P, meta["BLOCK_SIZE_P"])
    return (m_blocks * p_blocks, 1, 1)
43
+
44
+
45
+ b2b_gemm_left_template = TritonTemplate(
46
+ name="b2b_gemm_left",
47
+ grid=b2b_gemm_grid,
48
+ debug=False,
49
+ source=r"""
50
+ {{def_kernel("A", "B", "C")}}
51
+
52
+
53
+ # B2B_GEMM_LEFT_TRITON_ENTRANCE
54
+
55
+ # dynamic shapes
56
+ M = {{size("A", 0)}}
57
+ N = {{size("A", 1)}}
58
+ O = {{size("C", 0)}}
59
+ P = {{size("C", 1)}}
60
+
61
+ # dynamic strides
62
+ stride_am = {{stride("A", 0)}}
63
+ stride_an = {{stride("A", 1)}}
64
+ stride_bn = {{stride("B", 0)}}
65
+ stride_bo = {{stride("B", 1)}}
66
+ stride_co = {{stride("C", 0)}}
67
+ stride_cp = {{stride("C", 1)}}
68
+
69
+ # output block counts
70
+ num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
71
+ num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
72
+
73
+ # internal block counts
74
+ num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
75
+ num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
76
+
77
+ # output block ids
78
+ pid = tl.program_id(axis=0)
79
+ m_block_id = pid // num_p_block
80
+ p_block_id = pid % num_p_block
81
+
82
+ # accumulator
83
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
84
+
85
+ # main loop
86
+ offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
87
+ offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
88
+ # (subgraph(A @ B) @ C)
89
+ offs_o = tl.arange(0, BLOCK_SIZE_O)
90
+ for _ in range(num_o_block):
91
+ c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
92
+ c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
93
+ c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P
94
+ acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32)
95
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
96
+ for __ in range(num_n_block):
97
+ a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
98
+ a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
99
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N
100
+ b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
101
+ b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
102
+ b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O
103
+ acc_ab += tl.dot(a, b, out_dtype=tl.float32)
104
+ offs_n += BLOCK_SIZE_N
105
+ # apply the subgraph
106
+ {{ modification(
107
+ subgraph_number=0,
108
+ output_name="post_subgraph_acc_ab",
109
+ inner_mm="acc_ab"
110
+ ) | indent_except_first(2) }}
111
+ acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32)
112
+ offs_o += BLOCK_SIZE_O
113
+
114
+ # type conversion
115
+ acc = acc.to(tl.float16)
116
+
117
+ # store preparation
118
+ idx_m = offs_m[:, None]
119
+ idx_p = offs_p[None, :]
120
+ out_mask = (idx_m < M) & (idx_p < P)
121
+
122
+ {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}}
123
+ """,
124
+ )
125
+
126
+
127
+ b2b_gemm_right_template = TritonTemplate(
128
+ name="b2b_gemm_right",
129
+ grid=b2b_gemm_grid,
130
+ debug=False,
131
+ source=r"""
132
+ {{def_kernel("A", "B", "C")}}
133
+
134
+
135
+ # B2B_GEMM_RIGHT_TRITON_ENTRANCE
136
+
137
+ # dynamic shapes
138
+ M = {{size("A", 0)}}
139
+ N = {{size("A", 1)}}
140
+ O = {{size("C", 0)}}
141
+ P = {{size("C", 1)}}
142
+
143
+ # dynamic strides
144
+ stride_am = {{stride("A", 0)}}
145
+ stride_an = {{stride("A", 1)}}
146
+ stride_bn = {{stride("B", 0)}}
147
+ stride_bo = {{stride("B", 1)}}
148
+ stride_co = {{stride("C", 0)}}
149
+ stride_cp = {{stride("C", 1)}}
150
+
151
+ # output block counts
152
+ num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
153
+ num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
154
+
155
+ # internal block counts
156
+ num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
157
+ num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
158
+
159
+ # output block ids
160
+ pid = tl.program_id(axis=0)
161
+ m_block_id = pid // num_p_block
162
+ p_block_id = pid % num_p_block
163
+
164
+ # accumulator
165
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
166
+
167
+ # main loop (two cases)
168
+ offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
169
+ offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
170
+ # (A @ subgraph(B @ C))
171
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
172
+ for _ in range(num_n_block):
173
+ a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
174
+ a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
175
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N
176
+ acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32)
177
+ offs_o = tl.arange(0, BLOCK_SIZE_O)
178
+ for __ in range(num_o_block):
179
+ b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
180
+ b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
181
+ b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O
182
+ c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
183
+ c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
184
+ c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P
185
+ acc_bc += tl.dot(b, c, out_dtype=tl.float32)
186
+ offs_o += BLOCK_SIZE_O
187
+ # apply the subgraph
188
+ {{ modification(
189
+ subgraph_number=0,
190
+ output_name="post_subgraph_acc_bc",
191
+ inner_mm="acc_bc"
192
+ ) | indent_except_first(2) }}
193
+ acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32)
194
+ offs_n += BLOCK_SIZE_N
195
+
196
+ # type conversion
197
+ acc = acc.to(tl.float16)
198
+
199
+ # store preparation
200
+ idx_m = offs_m[:, None]
201
+ idx_p = offs_p[None, :]
202
+ out_mask = (idx_m < M) & (idx_p < P)
203
+
204
+ {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}}
205
+ """,
206
+ )
207
+
208
+
209
+ # Note: load_ratio_left and load_ratio_right are only calculating numbers
210
+ # in the trivial subgraph case; i.e. (A @ (B @ C)) or ((A @ B) @ C)
211
+
212
+
213
def load_ratio_left(
    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
) -> float:
    """
    Compute the ratio of estimated numbers of loads in baseline and b2bgemm
    for the left-associated case ((A @ B) @ C).
    M, N, O, P are matrix sizes; m, n, o, p are block sizes.
    |       | baseline (lower bound)        | b2bgemm
    | load  | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o))
    | store | M * O + M * P                 | M * P
    b2bgemm always wins on stores; this ratio identifies when it also wins
    (ratio > 1) or loses (ratio < 1) on loads.
    """
    baseline_loads = M * N + N * O + M * O + O * P
    # loads per output tile, summed over all output tiles and inner blocks
    per_tile = o * p + ceildiv(N, n) * (m * n + n * o)
    b2b_loads = ceildiv(M, m) * ceildiv(P, p) * ceildiv(O, o) * per_tile
    return baseline_loads / b2b_loads
233
+
234
+
235
def load_ratio_right(
    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
) -> float:
    """
    Compute the ratio of estimated numbers of loads in baseline and b2bgemm
    for the right-associated case (A @ (B @ C)).
    M, N, O, P are matrix sizes; m, n, o, p are block sizes.
    |       | baseline (lower bound)        | b2bgemm
    | load  | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p))
    | store | N * P + M * P                 | M * P
    b2bgemm always wins on stores; this ratio identifies when it also wins
    (ratio > 1) or loses (ratio < 1) on loads.
    """
    baseline_loads = N * O + O * P + M * N + N * P
    # loads per output tile, summed over all output tiles and inner blocks
    per_tile = m * n + ceildiv(O, o) * (n * o + o * p)
    b2b_loads = ceildiv(M, m) * ceildiv(P, p) * ceildiv(N, n) * per_tile
    return baseline_loads / b2b_loads
255
+
256
+
257
# the block sizes are limited by hardware (the shared memory)
# intuitively, the optimization works when the intermediate matrix is large
# and we assign large block sizes to large dimensions
#
# Four shape families, each at inner sizes 16/32/64; the smallest (16) case
# gets deeper pipelining (num_stages=4, num_warps=8), the rest use (2, 4).
_B2B_GEMM_BLOCK_SHAPES = (
    [(128, s, s, s) for s in (16, 32, 64)]      # large M
    + [(128, s, 128, s) for s in (16, 32, 64)]  # large M and O
    + [(s, s, s, 128) for s in (16, 32, 64)]    # large P
    + [(s, 128, s, 128) for s in (16, 32, 64)]  # large N and P
)
b2b_gemm_configs = [
    {
        "BLOCK_SIZE_M": m,
        "BLOCK_SIZE_N": n,
        "BLOCK_SIZE_O": o,
        "BLOCK_SIZE_P": p,
        "num_stages": 4 if min(m, n, o, p) == 16 else 2,
        "num_warps": 8 if min(m, n, o, p) == 16 else 4,
    }
    for (m, n, o, p) in _B2B_GEMM_BLOCK_SHAPES
]
358
+
359
+
360
def is_b2b_gemm_good_on(
    is_left_assoc: bool,
    A_node: torch.fx.Node,
    B_node: torch.fx.Node,
    C_node: torch.fx.Node,
) -> bool:
    """
    Check whether the operand sizes make B2B-GEMM worth dispatching to.

    Returns True only when the geometric mean of the top-3 estimated
    load ratios (baseline / b2bgemm) over all candidate configs exceeds 1.
    """
    # basic checks: all operands need fake-tensor metadata
    if any("val" not in node.meta for node in (A_node, B_node, C_node)):
        return False
    # torch._subclasses.fake_tensor.FakeTensor values
    A = A_node.meta["val"]
    B = B_node.meta["val"]
    C = C_node.meta["val"]
    if not (A.is_cuda and B.is_cuda and C.is_cuda):
        return False
    if not (len(A.shape) == 2 and len(B.shape) == 2 and len(C.shape) == 2):
        return False
    if A.shape[1] != B.shape[0] or B.shape[1] != C.shape[0]:
        return False
    # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1
    M, N = A.shape
    O, P = C.shape
    estimate = load_ratio_left if is_left_assoc else load_ratio_right
    ratios = [
        estimate(
            M,
            N,
            O,
            P,
            cfg["BLOCK_SIZE_M"],
            cfg["BLOCK_SIZE_N"],
            cfg["BLOCK_SIZE_O"],
            cfg["BLOCK_SIZE_P"],
        )
        for cfg in b2b_gemm_configs
    ]
    ratios.sort(reverse=True)
    # geometric mean of the top 3 choices
    product = 1.0
    for r in ratios[:3]:
        product *= r
    # even if the mean is close to 1, the number of stores is always better
    return product ** (1 / 3) > 1
421
+
422
+
423
+ def unoptimized_b2b_gemm(
424
+ is_left_assoc: bool,
425
+ subgraph: Subgraph,
426
+ A: torch.Tensor,
427
+ B: torch.Tensor,
428
+ C: torch.Tensor,
429
+ *,
430
+ out: torch.Tensor,
431
+ ) -> torch.Tensor:
432
+ """
433
+ The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial.
434
+ """
435
+ if is_left_assoc:
436
+ torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out)
437
+ else:
438
+ torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out)
439
+ return out
440
+
441
+
442
# Fallback ExternKernelChoice wrapping the eager implementation, so that
# autotuning can still pick the unfused path and avoid regressions.
unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm)
443
+
444
+
445
def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """
    This function is adapted from ../kernel/flex_attention.py.
    The goal is to take in the required args and produce the subgraph buffer.
    The subgraph buffer is a ComputedBuffer that will be inlined into the
    triton template.

    Args:
        args: The TensorBox inputs feeding the subgraph's placeholders,
            in placeholder order.
        subgraph: The Subgraph ir for which to produce the output node.

    Raises:
        ValueError: if the subgraph has no output node.
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        if node.op == "placeholder":
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args.
            # BUGFIX: bind to fresh names instead of rebinding the `args`
            # parameter -- the original clobbered `args`, breaking the
            # `args[cnt]` lookup for any placeholder visited after the first
            # call_function node.
            node_args, node_kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*node_args, **node_kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                # Map a single fx output node to an inlinable ComputedBuffer.
                if output is None:
                    return None
                output_node = output
                output_buffer = env[output_node]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ",
                    type(output_buffer),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] should be a single element representing the output of the subgraph
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("B2B-GEMM was passed a subgraph with no output node!")
501
+
502
+
503
def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """
    Create a placeholder input buffer (zero-dim layout) used to feed the
    subgraph when producing subgraph_output.
    """
    layout = FixedLayout(device, dtype, [], [])
    return TensorBox.create(InputBuffer(name, layout))
511
+
512
+
513
def tuned_b2b_gemm(
    is_left_assoc: bool,
    subgraph: Subgraph,
    A: torch._inductor.ir.TensorBox,
    B: torch._inductor.ir.TensorBox,
    C: torch._inductor.ir.TensorBox,
    *,
    layout=None,
) -> torch._inductor.ir.TensorBox:
    """
    Autotune B2B-GEMM over all candidate Triton configs plus the eager
    fallback, returning the selected algorithm's output TensorBox.
    """
    # call .realize() to get rid of Pointwise
    for box in (A, B, C):
        box.realize()
    # output is (M, P) in A's dtype on A's device
    layout = FixedLayout(A.get_device(), A.get_dtype(), [A.shape[0], C.shape[1]])
    subgraph_buffer = build_subgraph_buffer(
        [create_placeholder("inner_mm", A.get_dtype(), A.get_device())],
        subgraph,
    )
    template = b2b_gemm_left_template if is_left_assoc else b2b_gemm_right_template
    choices: list[TritonTemplateCaller] = []
    for cfg in b2b_gemm_configs:
        template.maybe_append_choice(
            choices,
            input_nodes=(A, B, C),
            layout=layout,
            subgraphs=[subgraph_buffer],
            **cfg,
        )
    # add the unoptimized choice to mitigate performance degradation
    choices.append(
        unoptimized_choice.bind(
            (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph
        )
    )
    # autotune
    return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout)
557
+
558
+
559
# match the inner mm of a potential b2b_gemm
@register_graph_pattern(
    CallFunction(torch.ops.aten.mm, Arg(), Arg()),
    pass_dict=B2B_GEMM_PASS,
)
def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None:
    """
    Rewrite (A @ subgraph(B @ C)) or (subgraph(A @ B) @ C) -- where `subgraph`
    is a chain of pointwise ops with no outside inputs/users -- into a single
    call to tuned_b2b_gemm, then erase the replaced nodes.

    The pattern fires on every aten.mm; all structural checks below decide
    whether this mm is really the inner matmul of a fusable b2b_gemm.
    """
    # match.args: list[torch.fx.Node]

    def is_pointwise_node(node: torch.fx.Node) -> bool:
        # op must be a call to an OpOverload tagged pointwise
        return (
            node.op == "call_function"
            and isinstance(node.target, torch._ops.OpOverload)
            and (torch.Tag.pointwise in node.target.tags)
        )

    def is_mm(node: torch.fx.Node) -> bool:
        return node.target == torch.ops.aten.mm.default

    # the inner MM
    inner_mm = match.nodes[-1]

    # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it
    # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _).
    # NOTE(review): this walk follows only the *first* user at each step;
    # the full multi-path check happens later in
    # all_reach_via_pointwise_with_no_other_inputs.
    outer_mm = None
    node = inner_mm
    while len(node.users) > 0:
        node = next(iter(node.users))
        if is_mm(node):
            outer_mm = node
            break
        elif is_pointwise_node(node):
            continue
        else:
            break
    if not outer_mm:
        return

    # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C))
    # we call it the "f_node"
    # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm
    f_node = inner_mm
    while next(iter(f_node.users)) is not outer_mm:
        f_node = next(iter(f_node.users))

    def all_reach_via_pointwise_with_no_other_inputs(
        src: torch.fx.Node,
        dst: torch.fx.Node,
    ) -> Tuple[bool, Set[torch.fx.Node]]:
        """
        check whether every user path from src reaches dst via pointwise nodes,
        with no other input nodes for the intermediates and dst;
        return
        (1) the Boolean value
        (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True)
        """
        visited: Set[torch.fx.Node] = set()
        # counts down each discovered node's inputs; a nonzero residue means
        # the node has an input from outside the candidate subgraph
        input_counter: Dict[torch.fx.Node, int] = {}

        all_reachable = True
        queue = deque([src])
        while queue:
            node = queue.popleft()
            if node not in visited:
                if node is dst:
                    visited.add(node)
                elif (node is src) or is_pointwise_node(node):
                    for user in node.users.keys():
                        # for nodes other than dst, bookkeep their users' input counts
                        if user not in input_counter:
                            input_counter[user] = len(user.all_input_nodes)
                        input_counter[user] -= 1
                        # continue BFS
                        queue.append(user)
                    visited.add(node)
                else:
                    # hit a non-pointwise node before dst: fail
                    all_reachable = False
                    break

        return (
            all_reachable and all(count == 0 for count in input_counter.values()),
            visited,
        )

    # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes
    ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs(
        inner_mm, f_node
    )
    if not ok:
        return

    # check inner_mm's inputs and f_node's outputs
    if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1):
        return

    # at this point, the nodes between inner_mm and f_node (both included)
    # are all used internally inside (A @ subgraph(B @ C))
    # i.e. they neither have other users nor have other inputs

    # original graph and module
    graph, module = inner_mm.graph, inner_mm.graph.owning_module

    # construct the new (sub)graph
    subgraph_node_list: List[
        torch.fx.Node
    ] = []  # ordered list of nodes used for node removal later
    new_graph: torch.fx.Graph = torch.fx.Graph()
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    new_input_anchor: torch.fx.Node  # inner_mm, to be changed to an input node
    new_output_anchor: torch.fx.Node  # f_node, to be used to construct an output node
    new_input_node: torch.fx.Node
    new_output_node: torch.fx.Node
    for node in graph.nodes:  # preserve the order of nodes
        if node in subgraph_node_set:
            subgraph_node_list.append(node)
            new_node = new_graph.node_copy(
                node, lambda x: node_remapping[x] if x in node_remapping else x
            )
            node_remapping[node] = new_node
            if node is inner_mm:
                new_input_anchor = new_node
            if node is f_node:
                new_output_anchor = new_node
    if new_input_anchor is not new_output_anchor:  # subgraph is non-trivial
        # update the input node
        with new_graph.inserting_before(new_input_anchor):
            new_input_node = new_graph.placeholder(name="subgraph_input")
            new_input_node.meta.update(new_input_anchor.meta)
            new_input_anchor.replace_all_uses_with(new_input_node)
        new_graph.erase_node(new_input_anchor)
        # add the output node
        new_output_node = new_graph.output(new_output_anchor)
        new_output_node.meta.update(new_output_anchor.meta)
    else:  # subgraph is trivial, e.g. (A @ (B @ C))
        # update the input node
        with new_graph.inserting_before(new_input_anchor):
            new_input_node = new_graph.placeholder(name="subgraph_input")
            new_input_node.meta.update(new_input_anchor.meta)
            new_input_anchor.replace_all_uses_with(new_input_node)
        new_graph.erase_node(new_input_anchor)
        # update the output node (don't use new_output_anchor since it has been erased)
        new_output_node = new_graph.output(new_input_node)
        new_output_node.meta.update(new_input_node.meta)
    new_graph.lint()

    # construct the subgraph
    subgraph = Subgraph(
        name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph)
    )

    # two cases
    # (1) (subgraph(A @ B) @ C), called "left_assoc"
    # (2) (A @ subgraph(B @ C)), called "right_assoc"
    is_left_assoc = outer_mm.args[0] is f_node

    # find the nodes A, B, C and check the sizes
    A: torch.fx.Node
    B: torch.fx.Node
    C: torch.fx.Node
    if is_left_assoc:
        A = inner_mm.args[0]  # type: ignore[assignment]
        B = inner_mm.args[1]  # type: ignore[assignment]
        C = outer_mm.args[1]  # type: ignore[assignment]
    else:
        A = outer_mm.args[0]  # type: ignore[assignment]
        B = inner_mm.args[0]  # type: ignore[assignment]
        C = inner_mm.args[1]  # type: ignore[assignment]
    if not is_b2b_gemm_good_on(is_left_assoc, A, B, C):
        return

    # finally update the original graph
    counters["inductor"]["b2b_gemm"] += 1
    graph = match.graph
    with graph.inserting_before(outer_mm):
        # a partial so the lowering carries is_left_assoc and the subgraph
        function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph)
        function.__name__ = tuned_b2b_gemm.__name__  # type: ignore[attr-defined]
        function._inductor_lowering_function = True  # type: ignore[attr-defined]
        replacement: torch.fx.Node = graph.call_function(
            function,
            (A, B, C),
            match.kwargs,
        )
        replacement.meta.update(outer_mm.meta)
        outer_mm.replace_all_uses_with(replacement)
    # erase unnecessary nodes
    graph.erase_node(outer_mm)
    for node in reversed(subgraph_node_list):
        graph.erase_node(node)
    graph.lint()
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+
5
+ import torch
6
+
7
+ from ..._dynamo.utils import counters
8
+ from ..pattern_matcher import Arg, CallFunction, KeywordArg
9
+ from .freezing_patterns import register_binary_folding_pattern
10
+
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+
16
def mark_mixed_dtype_conv(conv):
    """
    Tag a fp16/bf16 convolution whose single-use chain of binary ops (in
    fp32) terminates in a convert_element_type, so binary folding may run in
    the higher precision.  The original dtype is stashed in node meta under
    "_allow_conv_mixed_dtype_folding".
    """
    conv_dtype = conv.meta["val"].dtype
    if conv_dtype not in (torch.float16, torch.bfloat16):
        return

    if len(conv.users) != 1:
        return

    user = next(iter(conv.users.keys()))
    if not isinstance(user.meta["val"], torch.Tensor):
        return

    if user.meta["val"].dtype != torch.float32:
        return

    # walk the single-use chain of foldable binary ops
    while user.target in _binary_ops:
        if len(user.users) != 1:
            return
        user = next(iter(user.users.keys()))

    # the chain must end in a dtype conversion node
    if user.target != prims.convert_element_type.default:
        return

    conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype
41
+
42
+
43
def mark_mixed_dtype_allowed_convs(gm):
    """
    Mark convolutions which we will binary fold even with mixed precision constants.
    We constant fold in the higher precision for better accuracy and then recover
    the original precision after.
    """
    conv_nodes = gm.graph.find_nodes(
        op="call_function", target=aten.convolution.default
    )
    for conv_node in conv_nodes:
        mark_mixed_dtype_conv(conv_node)
52
+
53
+
54
def recover_original_precision_folded_convs(gm):
    """
    After binary folding conv weights and biases to a higher dtype, recover the
    original precision they were in.
    """
    graph = gm.graph
    for conv in graph.find_nodes(op="call_function", target=aten.convolution.default):
        orig_dtype = conv.meta.get("_allow_conv_mixed_dtype_folding", None)
        if orig_dtype is None:
            continue

        with graph.inserting_before(conv):
            # args[1] is the weight, args[2] the bias: cast each back
            for idx in (1, 2):
                operand = conv.args[idx]
                if operand is None:
                    continue
                cast_back = graph.create_node(
                    "call_function",
                    prims.convert_element_type.default,
                    (operand, orig_dtype),
                )
                conv.replace_input_with(operand, cast_back)
77
+
78
# Elementwise binary ops that may be constant-folded into a preceding conv.
_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
79
+
80
+
81
@functools.lru_cache(None)
def binary_folding_init():
    """
    Register the conv+binary constant-folding graph patterns.

    Wrapped in lru_cache(None) so the pattern registration (a process-global
    side effect) runs at most once.
    """
    _conv_args = [Arg() for _ in range(9)]
    _computation_ops = [aten.convolution.default]
    _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]

    """
    In order to fuse add/sub/mul/div with conv, the dimensions of its
    constant tensor must satisfy the following:
    - with resizing, broadcast to w/ weight/bias tensor shape
    - broadcast to the conv output shape
    It needs to have a shape that can resize to weight/bias
    tensor shape because we need to run the op with the conv
    weights/bias without changing their sizes.
    It needs to broadcast to the conv output shape so that we do
    accidentally change the shape of op output by pre-fusing it
    compared to eager.
    The only dimension value shared by weight/bias/conv output
    is they all contain a dim with value = channels-out. In the
    conv output tensor, this is in the second dimension,
    so the pointwise op tensor may have a second dimension of
    value == channels-out, but all the other dimensions have to be 1
    """

    def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
        # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
        weight_shape = weight_tensor.shape
        other_shape = other_tensor.shape
        if len(weight_shape) < len(other_shape):
            return False
        if len(weight_shape) == len(other_shape) + 1:
            # weight shape is [o, i, *], other_shape is [o, 1...].
            for i in reversed(range(len(other_shape))):
                if i == 0 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        else:
            # weight shape is [o, i, *], other_shape is [1, i, *]
            for i in reversed(range(len(other_shape))):
                if i == 1 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        return True

    def _check_conv_and_broadcast_op(conv_node, other):
        # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
        # conv.weight
        if conv_node.args[1].op != "get_attr":
            return False
        # conv.bias
        # BUGFIX: the bias is args[2]; the original re-checked args[1] (the
        # weight), so a non-constant bias could slip past this precondition.
        if conv_node.args[2] is not None and conv_node.args[2].op != "get_attr":
            return False
        if (
            not isinstance(other, int)
            and not isinstance(other, float)
            and other.op != "get_attr"
        ):
            return False

        if not len(conv_node.args[1].users) == 1:
            return False

        weight_meta_value = conv_node.args[1].meta.get("val")
        if weight_meta_value is None:
            return False
        # Avoid fusing op that causes type promotion
        # restricting to float avoids int/float difficulties with scalar overload
        if not weight_meta_value.is_floating_point():
            return False
        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
            other_meta_value = other.meta.get("val")
            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
                return False
            if (
                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
                != weight_meta_value.dtype
            ):
                # mixed-dtype folding is only allowed for marked convs
                if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
                    return False

                if (
                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
                ):
                    return False

            if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
                return False
        else:
            # TODO: support scalar case
            return False

        return True

    def _is_foldable_pattern(match):
        binary_node = match.output_node()
        computation_node = binary_node.args[0]
        other = binary_node.args[1]
        if binary_node.args[0].target not in _computation_ops:
            computation_node = binary_node.args[1]
            other = binary_node.args[0]
        # BUGFIX: check the resolved computation_node, not args[0]; the
        # original re-tested args[0] and therefore rejected every pattern
        # where the conv is the binary op's second operand.
        if computation_node.target == aten.convolution.default:
            return _check_conv_and_broadcast_op(computation_node, other)

        return False

    def resize_scalar_or_tensor_to_shape(graph, other, shape):
        # TODO: support scalar case
        if other.meta.get("val").numel() == 1:
            # expand errors if the shape input has less # dims than the tensor input
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, (1,)),
            )
            res = graph.create_node(
                "call_function",
                aten.expand.default,
                (res, shape),
            )
        else:
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, shape),
            )
        return res

    def _create_new_conv_node(graph, conv_node, binary_node, other):
        # Build a conv node equivalent to binary_node(conv_node, other) by
        # folding `other` into the conv's weight and/or bias.
        assert conv_node.target == aten.convolution.default
        conv_args = list(conv_node.args)
        weight_meta_value = conv_node.args[1].meta.get("val")
        bias = conv_args[2]
        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
            # add/sub only touch the bias
            other_reshape = resize_scalar_or_tensor_to_shape(
                graph, other, (weight_meta_value.size(0),)
            )
            new_bias = graph.create_node(
                "call_function",
                binary_node.target,
                (0 if bias is None else bias, other_reshape),
            )
            conv_args[2] = new_bias
        else:
            # mul/div scale both the weight and (if present) the bias
            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
            weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
            weight_broadcast_shape[0] = weight_meta_value.size(0)
            other_reshape1 = resize_scalar_or_tensor_to_shape(
                graph, other, tuple(weight_broadcast_shape)
            )
            new_weight = graph.create_node(
                "call_function", binary_node.target, (conv_args[1], other_reshape1)
            )
            new_weight.meta.update(conv_args[1].meta)
            conv_args[1] = new_weight
            if bias is not None:
                other_reshape = resize_scalar_or_tensor_to_shape(
                    graph, other, (weight_meta_value.size(0),)
                )
                new_bias = graph.create_node(
                    "call_function", binary_node.target, (bias, other_reshape)
                )
                new_bias.meta.update(bias.meta)
                conv_args[2] = new_bias
        return graph.create_node("call_function", conv_node.target, tuple(conv_args))

    for _computation_call, binary_op in itertools.product(
        _computation_calls, _binary_ops
    ):

        @register_binary_folding_pattern(
            CallFunction(binary_op, _computation_call, KeywordArg("other")),
            extra_check=_is_foldable_pattern,
        )
        def folded_op(match, *args, **kwargs):
            counters["inductor"]["binary_folding"] += 1
            other = kwargs.get("other")
            binary_node = match.output_node()
            computation_node = (
                binary_node.args[0]
                if binary_node.args[0].target in _computation_ops
                else binary_node.args[1]
            )
            graph = match.graph
            with graph.inserting_before(binary_node):
                # TODO: support linear?
                assert computation_node.target == aten.convolution.default
                new_computation_node = _create_new_conv_node(
                    graph, computation_node, binary_node, other
                )
                binary_node.replace_all_uses_with(new_computation_node)
                new_computation_node.meta.update(computation_node.meta)
                graph.erase_node(binary_node)
                graph.erase_node(computation_node)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Owner(s): ["oncall: distributed"]
2
+ import collections
3
+ import inspect
4
+ import logging
5
+ import math
6
+ import operator
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ cast,
13
+ Dict,
14
+ Generator,
15
+ List,
16
+ Optional,
17
+ Set,
18
+ Tuple,
19
+ Union,
20
+ )
21
+
22
+ import torch
23
+ import torch.fx as fx
24
+ from torch._dynamo.utils import counters
25
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
26
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
27
+ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
28
+
29
+ from .. import config
30
+ from ..fx_utils import get_fake_args_kwargs
31
+ from ..virtualized import V
32
+
33
+
34
+ aten = torch.ops.aten
35
+ logger: logging.Logger = logging.getLogger("comm_fusion")
36
+
37
+
38
def move_block_after(block: List[fx.Node], target_node: fx.Node) -> None:
    """Re-insert the nodes of `block`, in order, directly after `target_node`."""
    anchor = target_node
    for moved in block:
        anchor.append(moved)
        anchor = moved
42
+
43
+
44
def move_block_before(block: List[fx.Node], target_node: fx.Node) -> None:
    """Re-insert the nodes of `block`, in order, directly before `target_node`."""
    anchor = target_node
    for moved in block:
        anchor.prepend(moved)
        anchor = moved
48
+
49
+
50
def call_function(
    graph: fx.Graph,
    target: Union[str, Callable[..., Any]],
    args: Optional[Tuple[fx.node.Argument, ...]] = None,
    kwargs: Optional[Dict[str, fx.node.Argument]] = None,
) -> fx.Node:
    """
    Insert a call_function node into `graph` and eagerly populate its
    fake-tensor metadata ("val" and "tensor_meta") by running `target`
    under the fake mode.

    Raises:
        RuntimeError: if `target` is a string rather than a callable.
    """
    # We accept target as a str to avoid typing error as the type of
    # a node.target is Union[str, Callable[..., Any]].
    # This also allows us to avoid writing check for every call.
    if isinstance(target, str):
        raise RuntimeError(f"Call function should not get a str target {target=}")
    node = graph.call_function(target, args, kwargs)
    # re-extract args/kwargs with fake values so `target` can be executed
    _, args, kwargs = get_fake_args_kwargs(node)
    with V.fake_mode:
        node.meta["val"] = target(*args, **kwargs)
        # node.meta["val"] may be a container. So we use tree_map here
        # to recursively extract the tensor metadata.
        node.meta["tensor_meta"] = tree_map(
            _extract_tensor_metadata, (node.meta["val"],)
        )[0]
    return node
71
+
72
+
73
@dataclass(unsafe_hash=True)
class CommBlock:
    """Container for all fx nodes that belong to one collective communication."""

    # Tensor shape(s) associated with the communicated data -- a single
    # torch.Size or a list of them for multi-output collectives.
    shape: Union[torch.Size, List[torch.Size]]
    # All member nodes of this communication block.
    node_list: List[fx.Node]
    # Input nodes feeding the collective call.
    inputs: List[fx.Node]
    # The wait_tensor node(s) synchronizing on the collective's result.
    wait_nodes: List[fx.Node]
    # The collective call node itself (e.g., an allreduce).
    comm_node: fx.Node
    # Nodes whose values escape the block to external users.
    outputs: Set[fx.Node]
81
+
82
+
83
def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
    """
    Given a collective node (e.g., allreduce), find out all the nodes belong to
    this communication.

    Args:
        comm_node(fx.Node): The target communication/collective node.
    Returns:
        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
        the given comm_node, or None if the nodes around ``comm_node`` do not
        match a recognized collective pattern.
    """
    node_list = []
    wait_nodes = []
    inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs))
    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
    # If the users of the wait node are following items, we consider them
    # to be a part of the output.
    intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias")

    first_user = next(iter(comm_node.users))
    if (
        len(comm_node.users) == 1
        and first_user.target == torch.ops._c10d_functional.wait_tensor.default
    ):
        # Collective with only one output.
        node_list = [comm_node, first_user]
        wait_nodes.append(first_user)
    elif len(comm_node.users) > 1 and first_user.target == operator.getitem:
        # Collective with more than one output: every user must be a getitem
        # whose single user is a wait_tensor, otherwise bail out.
        node_list.append(comm_node)
        for user in comm_node.users:
            if user.target != operator.getitem:
                return None
            if len(user.users) != 1:
                return None
            wait_node = next(iter(user.users))
            if wait_node.target != torch.ops._c10d_functional.wait_tensor.default:
                return None
            wait_nodes.append(wait_node)
            node_list.append(user)
        node_list.extend(wait_nodes)
    else:
        return None

    # Identify all the outputs of this collective block.
    outputs: Set[fx.Node] = set()
    nodes = collections.deque(wait_nodes)
    while nodes:
        node = nodes.popleft()
        for user in node.users:
            if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs):
                nodes.append(user)
                node_list.append(user)
            else:
                outputs.add(node)
                break

    tensor_meta = input_nodes[0].meta["tensor_meta"]
    shape: Union[torch.Size, List[torch.Size]]
    if isinstance(tensor_meta, TensorMetadata):
        shape = tensor_meta.shape
    elif isinstance(tensor_meta, (list, tuple)):
        shape = [tm.shape for tm in tensor_meta]
    else:
        logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta))
        return None

    return CommBlock(
        shape=shape,
        node_list=node_list,
        wait_nodes=wait_nodes,
        comm_node=comm_node,
        inputs=input_nodes,
        outputs=outputs,
    )
159
+
160
+
161
def get_all_comm_blocks(
    graph: fx.Graph,
    comm_ops: Tuple[torch._ops.OpOverload, ...],
    comm_filter: Optional[Callable[..., bool]] = None,
) -> List[CommBlock]:
    """Collect every CommBlock in ``graph`` whose comm op is one of ``comm_ops``.

    ``comm_filter`` may restrict which parsed blocks are kept; by default
    every successfully parsed block passes.
    """
    keep = comm_filter if comm_filter is not None else (lambda _block: True)

    matched: List[CommBlock] = []
    for candidate in graph.nodes:
        if candidate.target not in comm_ops:
            continue
        block = get_comm_block(candidate)
        if block is not None and keep(block):
            matched.append(block)
    return matched
181
+
182
+
183
def _fuse_allreduce_by_concat(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using concat.

    Flattens and concatenates all gradient inputs into one tensor, divides it
    once, and issues a single all_reduce + wait over the concatenated buffer.
    """
    # Flatten all the inputs to the all_reduce nodes.
    with graph.inserting_after(last_input_node):
        cat_inputs = []
        for input_node in all_input_nodes:
            # Each input here is a div node; its first arg is the raw gradient.
            assert isinstance(input_node.args[0], fx.Node)
            input_node = input_node.args[0]
            cat_inputs.append(
                call_function(graph, aten.flatten.using_ints, (input_node,))
            )

    # Concat all the flattened nodes.
    with graph.inserting_after(cat_inputs[0]):
        cat_node = call_function(graph, aten.cat, (cat_inputs,))

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    divisors = [div.args[1] for div in all_input_nodes]
    # All gradients must be averaged by the same divisor for a single fused div.
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_after(cat_node):
        div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0]))

    # Create a new Comm/all_reduce node, reusing the last comm node's
    # args/kwargs but swapping in the fused input.
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]
    with graph.inserting_after(div_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = div_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs)

    # Create a new Wait node.
    with graph.inserting_after(fused_comm_node):
        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
        flatten_args[0] = fused_comm_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs)

    # Move the fused all_reduce and its args to right after the input node.
    nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node]
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape,
        node_list=[fused_comm_node, fused_wait_node],
        wait_nodes=[fused_wait_node],
        comm_node=fused_comm_node,
        inputs=[div_node],
        outputs={fused_wait_node},
    )
239
+
240
+
241
def _fuse_with_coalesced_op(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce by coalesced.

    Uses ``all_reduce_coalesced`` so the individual gradients stay separate
    tensors; results are recovered with per-index getitem + wait nodes.
    """
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    dividends = [div.args[0] for div in all_input_nodes]
    divisors = [div.args[1] for div in all_input_nodes]
    # A single _foreach_div requires all inputs to share the divisor.
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_before(last_input_node):
        last_input_node = call_function(
            graph, aten._foreach_div.Scalar, (dividends, divisors[0])
        )
    input_node = last_input_node

    # Create a new Comm/all_reduce_coalesced node, reusing the last comm
    # node's args/kwargs with the fused input list swapped in.
    with graph.inserting_after(last_comm_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = input_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(
            graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs
        )

    # Create a new wait node per fused output (getitem -> wait).
    getitem_nodes = []
    wait_nodes = []
    flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
    for idx in range(len(all_input_nodes)):
        with graph.inserting_after(fused_comm_node):
            gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx))
        getitem_nodes.append(gi_node)
        flatten_args[0] = gi_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        with graph.inserting_after(gi_node):
            wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs))

    # Move the new all_reduce_coalesced and its args to right after the input node.
    nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=[
            tm.shape
            for tm in cast(
                List[TensorMetadata], fused_comm_node.meta.get("tensor_meta")
            )
        ],
        node_list=[fused_comm_node] + getitem_nodes + wait_nodes,
        wait_nodes=wait_nodes,
        comm_node=fused_comm_node,
        inputs=[input_node],
        outputs=set(wait_nodes),
    )
301
+
302
+
303
def _scatter_fused_allreduce_waits(
    graph: fx.Graph,
    fused_comm_block: CommBlock,
    orig_comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    split_and_reshape: bool = True,
) -> None:
    """
    Scatters the result of the fused communication node to the original users.
    If the fused method is concat, a split of the output and a reshape will be
    inserted before the getitems. Otherwise the wait nodes themselves are used
    as the per-gradient outputs.
    """

    # Before we mess up the order, we need to get the index of the last wait node
    # in orig_comm_blocks. This index will be later used to determine what user
    # nodes need to be moved to maintain a correct topological sort order.
    last_wait_node_idx = 0
    for node in graph.nodes:
        last_wait_node_idx = max(
            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
        )
        if node == orig_comm_blocks[-1].wait_nodes[0]:
            break

    if split_and_reshape:
        # Concat fusion: split the flat fused buffer back into per-gradient
        # chunks and reshape each chunk to its original shape.
        fused_wait_node = fused_comm_block.wait_nodes[0]
        with graph.inserting_after(fused_wait_node):
            split_node = call_function(
                graph,
                aten.split,
                (
                    fused_wait_node,
                    [math.prod(cast(List[int], cb.shape)) for cb in orig_comm_blocks],
                ),
            )
        with graph.inserting_after(split_node):
            fused_outputs = []
            for idx, comm_block in enumerate(orig_comm_blocks):
                split_idx_node = call_function(
                    graph, operator.getitem, (split_node, idx)
                )
                with graph.inserting_after(split_idx_node):
                    fused_outputs.append(
                        call_function(
                            graph, aten.reshape, (split_idx_node, comm_block.shape)
                        )
                    )
    else:
        fused_outputs = fused_comm_block.wait_nodes

    # Scatter the fused outputs.
    incorrect_order_nodes = []
    for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs):
        # Some descendant users of the orig_comm_blocks may be scheduled before
        # the fused all_reduce. For example, the user nodes of the very first
        # all_reduce may be scheduled before the second all_reduce. Since the
        # fused all_reduce is inserted right after the last all_reduce, the
        # order can be wrong.
        # `incorrect_order_nodes` records these nodes.

        orig_wait = comm_block.wait_nodes[0]
        nodes = collections.deque(list(orig_wait.users))
        while nodes:
            user_node = nodes.popleft()
            if not isinstance(user_node, fx.Node):
                continue
            if node_indices[user_node] < last_wait_node_idx:
                incorrect_order_nodes.append(user_node)
                nodes.extend(list(user_node.users))

        orig_wait.replace_all_uses_with(fused_output)

    # Find the last fused output in the current graph order; misordered
    # users are moved after it.
    last_fused_result = fused_outputs[0]
    fused_outputs_set = set(fused_outputs)
    for node in graph.nodes:
        if node in fused_outputs_set:
            last_fused_result = node

    # Move the incorrect_order_nodes to right after the last fused_result.
    incorrect_order_nodes = sorted(
        incorrect_order_nodes, key=lambda node: node_indices[node]
    )
    move_block_after(incorrect_order_nodes, last_fused_result)
387
+
388
+
389
def _fuse_allreduce(
    graph: fx.Graph,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    use_concat: bool,
) -> CommBlock:
    """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.

    ``use_concat`` selects between concat-based fusion and the
    all_reduce_coalesced based fusion.
    """

    if len(comm_blocks) == 1:
        # Nothing to fuse.
        return comm_blocks[0]

    # Find the last input node of all the CommBlocks. This node will be served
    # as the inserting point of the new collective op.
    last_input_node = comm_blocks[0].inputs[0]
    last_input_index = -1
    all_input_nodes = []
    for comm_block in comm_blocks:
        input_node = comm_block.inputs[0]
        all_input_nodes.append(input_node)
        index = node_indices[input_node]
        if index >= last_input_index:
            # Indices come from a graph enumeration, so they must be unique.
            assert index != last_input_index
            last_input_node = input_node
            last_input_index = index

    if use_concat:
        fused_comm_block = _fuse_allreduce_by_concat(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )
    else:
        fused_comm_block = _fuse_with_coalesced_op(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )

    _scatter_fused_allreduce_waits(
        graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat
    )

    # Erase the now-replaced original collective nodes.
    for comm_block in comm_blocks:
        for wait in comm_block.wait_nodes:
            graph.erase_node(wait)
        graph.erase_node(comm_block.comm_node)
    graph.eliminate_dead_code()

    return fused_comm_block
434
+
435
+
436
def _bucket_size_fusion(
    graph: fx.Graph, comm_blocks: List[CommBlock], bucket_size_mb: int
) -> Generator[List[CommBlock], None, None]:
    """Yield groups of CommBlocks whose accumulated payload fills a bucket.

    The first bucket is capped at 1MB so the first fused allreduce can launch
    early; later buckets use ``bucket_size_mb``. The trailing (possibly
    partial) bucket is always yielded.
    """
    MB = 1024**2
    # First bucket is intentionally small; enlarged after the first yield.
    bucket_size = 1 * MB
    bucket_cap_size = bucket_size_mb * MB
    curr_size = 0
    curr_blocks = []

    count = 0
    fuse_count = 0
    for i, block in enumerate(comm_blocks):
        curr_blocks.append(block)
        # Payload bytes of this collective = numel * dtype itemsize.
        itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize
        curr_size += cast(torch.Size, block.shape).numel() * itemsize
        count += 1
        if curr_size < bucket_size and i != len(comm_blocks) - 1:
            continue

        fuse_count += 1
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d",
                fuse_count,
                count,
                curr_size,
                bucket_size,
            )

        # Set the debug counters
        counters["inductor"]["ddp_buckets"] = fuse_count
        yield curr_blocks

        # Reset the accumulator for the next (full-size) bucket.
        bucket_size = bucket_cap_size
        curr_blocks = []
        curr_size = 0
        count = 0
473
+
474
+
475
def _fuse_ddp_communication(
    graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any]
) -> None:
    """Drive DDP allreduce fusion.

    Groups candidate allreduce blocks with ``algorithm_fn`` (e.g. bucketing)
    and fuses each group with ``fusion_fn``.
    """
    # Locate the graph's output node; iterating from the end it is the
    # first node found.
    for output in reversed(graph.nodes):
        if output.op == "output":
            break

    def ddp_reducer_filter(block: CommBlock) -> bool:
        # Keep only allreduces that match DDP's gradient-reduction pattern:
        # the input must be a div (gradient averaging by world size).
        if (
            not isinstance(block.comm_node.args[0], fx.Node)
            or block.comm_node.args[0].target != aten.div.Tensor
        ):
            return False

        if len(block.wait_nodes[0].users) != 1:
            # gradient/wait node should only be used by one user
            return False

        # Two cases:
        # 1. gradient/wait node should be directly used by the output
        # if gradient is None before bwd.
        # 2. gradient/wait node should be directly used by copy_.
        if (
            output not in block.wait_nodes[0].users
            and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default
        ):
            return False

        return True

    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter)
    # Snapshot topological positions before any mutation.
    node_indices = {node: i for i, node in enumerate(graph.nodes)}

    for block in algorithm_fn(graph, comm_blocks):
        fusion_fn(graph, block, node_indices)
514
+
515
+
516
def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    """Bucket DDP allreduces and fuse each bucket via all_reduce_coalesced."""
    bucketing = partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb)
    fuse = partial(_fuse_allreduce, use_concat=False)
    _fuse_ddp_communication(graph, bucketing, fuse)
522
+
523
+
524
def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    """Bucket DDP allreduces and fuse each bucket via flatten+concat."""
    bucketing = partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb)
    fuse = partial(_fuse_allreduce, use_concat=True)
    _fuse_ddp_communication(graph, bucketing, fuse)
530
+
531
+
532
def schedule_comm_wait(graph: fx.Graph) -> None:
    """
    Delay the execution of wait tensors of allreduce until its first user.

    This algorithm considers the intermediate users, like split, getitem,
    of the wait node and schedules those intermediate users as well.
    This will result in a better overlapping result.
    """
    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
        torch.ops._c10d_functional.all_reduce_coalesced.default,
        torch.ops._c10d_functional.all_reduce_coalesced_.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops)
    if not comm_blocks:
        return

    # Find all the end users.
    allreduce_users: Set[fx.Node] = set()
    for allreduce in comm_blocks:
        for output in allreduce.outputs:
            allreduce_users.update(output.users)

    node_indices = {node: i for i, node in enumerate(graph.nodes)}
    for allreduce in comm_blocks:
        # Find the earliest/first user -- target_node.
        assert (
            len(allreduce.outputs) >= 1
        ), f"Found a allreduce that has zero outputs/users -- {allreduce}."
        # Initialize the target node to avoid typing issues.
        target_node = next(iter(next(iter(allreduce.outputs)).users))
        target_node_index = 2**31
        for user in (user for output in allreduce.outputs for user in output.users):
            index = node_indices[user]
            if index < target_node_index:
                target_node = user
                target_node_index = index

        # Move wait nodes and all the subsequent nodes in the comm_block to
        # before the first user -- target_node.
        wait_idx = -1
        for wait_idx, node in enumerate(allreduce.node_list):
            if node == allreduce.wait_nodes[0]:
                break
        assert wait_idx >= 0
        move_block_before(allreduce.node_list[wait_idx:], target_node)
579
+
580
+
581
def fuse_ddp_communication(
    graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int
) -> None:
    """Run the configured DDP-communication fusion passes over ``graph``.

    Each entry in ``passes`` is a callable or the name of a module-level
    function; passes that declare a ``bucket_size_mb`` parameter receive it.
    """
    for i, pa in enumerate(passes):
        with GraphTransformObserver(
            graph.owning_module,
            f"fuse_ddp_communication_pass_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            # A pass given by name is looked up in this module's globals.
            if isinstance(pa, str):
                func = globals()[pa]
            else:
                func = pa
            # Only forward bucket_size_mb to passes that accept it.
            if "bucket_size_mb" in {
                v.name for v in inspect.signature(func).parameters.values()
            }:
                func(graph, bucket_size_mb=bucket_size_mb)
            else:
                func(graph)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from typing import List
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch._dynamo.utils import counters
8
+
9
+ from .. import config
10
+ from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
11
+ from .split_cat import construct_pattern_matcher_pass
12
+
13
+
14
aten = torch.ops.aten
log = logging.getLogger(__name__)

# TODO: need a better strategy for decomposing mm
# Default thresholds: only decompose when the leading dimension is very
# large and the other dimensions are small (memory-bound shapes).
MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
MAX_OTHER_DIMENSION_DECOMPOSITION = 32

# Thresholds may be overridden through
# config.post_grad_fusion_options["decompose_mm_pass"].
min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION
max_other_dimention_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION
if "decompose_mm_pass" in config.post_grad_fusion_options:
    min_first_dimension_decomposition = config.post_grad_fusion_options[
        "decompose_mm_pass"
    ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION)
    max_other_dimention_decomposition = config.post_grad_fusion_options[
        "decompose_mm_pass"
    ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
30
+
31
+
32
def check_device(a: Tensor, b: Tensor) -> bool:
    """Return True only when both operands live on a CUDA device."""
    return all(t.is_cuda for t in (a, b))
34
+
35
+
36
def realize_inputs(inputs: List[torch.fx.Node]):
    """Mark every fx.Node in ``inputs`` for stride-preserving realization.

    Non-Node entries (constants, None) are skipped silently.
    """
    for candidate in inputs:
        if not isinstance(candidate, torch.fx.node.Node):
            continue
        candidate.meta["inductor_realize_to_strides"] = True
40
+
41
+
42
def should_decompose_bmm(mat1, mat2) -> bool:
    """Decide whether bmm(mat1, mat2) is memory bound and worth decomposing.

    Requires fake-tensor metadata on both nodes, CUDA operands, 3-D shapes,
    a large batch dimension, and small inner dimensions.
    """
    if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
        mat1 = mat1.meta["val"]
        mat2 = mat2.meta["val"]
    else:
        return False
    if not check_device(mat1, mat2):
        return False
    else:
        if len(mat1.shape) != 3 or len(mat2.shape) != 3:
            return False
        # Batch dimension must be large enough to be memory bound.
        if mat1.shape[0] < min_first_dimension_decomposition:
            return False
        # 2 of m, n, k must be < max_other_dimention_decomposition
        if (mat1.shape[1] < max_other_dimention_decomposition) + (
            mat1.shape[2] < max_other_dimention_decomposition
        ) + (mat2.shape[2] < max_other_dimention_decomposition) < 2:
            return False
        return True
61
+
62
+
63
def should_decompose_mm(mat1, mat2) -> bool:
    """Decide whether mm(mat1, mat2) is memory bound and worth decomposing.

    True only for CUDA 2-D operands where mat1 has a large leading dimension
    and both dimensions of mat2 are below the configured threshold.
    """
    if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
        # Replace the nodes by their fake tensors for shape/device checks.
        mat1 = mat1.meta["val"]
        mat2 = mat2.meta["val"]
    else:
        return False
    return (
        check_device(mat1, mat2)
        and len(mat1.shape) == 2
        and len(mat2.shape) == 2
        and mat1.shape[0] >= min_first_dimension_decomposition
        and mat2.shape[0] < max_other_dimention_decomposition
        and mat2.shape[1] < max_other_dimention_decomposition
    )
77
+
78
+
79
def is_node_meta_valid(node: torch.fx.Node):
    # FakeTensor propagation stores the example value under "val"; without it
    # shapes/devices cannot be inspected, so the node is unusable here.
    return "val" in node.meta
81
+
82
+
83
def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
    """Log (debug level) which op was decomposed, with its input shapes."""
    # The last node in the match is the matmul-like op being replaced.
    node = match.nodes[-1]
    log.debug(
        "Decompose %s with input shape: %s",
        node.target,
        ", ".join(
            str(input.meta["val"].shape) if "val" in input.meta else "None"
            for input in inputs
        ),
    )
93
+
94
+
95
@register_graph_pattern(
    CallFunction(aten.bmm, Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
    """Rewrite a memory-bound bmm as broadcasted multiply + sum reduction."""

    def repl(mat1, mat2):
        # (B, M, K, 1) * (B, 1, K, N) summed over K reproduces bmm(B, M, N).
        return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to(
            mat1.dtype
        )

    if should_decompose_bmm(mat1, mat2):
        counters["inductor"]["decompose_bmm"] += 1
        match.replace_by_example(repl, [mat1, mat2])
        print_decompose_pattern(match, [mat1, mat2])
        realize_inputs([mat1, mat2])
    return
111
+
112
+
113
@register_graph_pattern(
    CallFunction(aten.addmm, Arg(), Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_addmm(
    match: Match,
    mat1: torch.fx.Node,
    mat2: torch.fx.Node,
    mat3: torch.fx.Node,
):
    """Rewrite a memory-bound addmm as multiply + sum reduction plus bias."""

    def repl(mat1, mat2, mat3):
        # (M, K, 1) * (1, K, N) summed over K gives mm(mat2, mat3); mat1 is
        # the additive bias term of addmm.
        return (
            torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1
        )

    if should_decompose_mm(mat2, mat3):
        counters["inductor"]["decompose_addmm"] += 1
        match.replace_by_example(repl, [mat1, mat2, mat3])
        print_decompose_pattern(match, [mat1, mat2, mat3])
        realize_inputs([mat1, mat2, mat3])
    return
134
+
135
+
136
@register_graph_pattern(
    CallFunction(aten.mm, Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_mm(
    match: Match,
    mat1: torch.fx.Node,
    mat2: torch.fx.Node,
):
    """Rewrite a memory-bound mm as broadcasted multiply + sum reduction."""

    def repl(mat1, mat2):
        # (M, K, 1) * (1, K, N) summed over K reproduces mm(M, N).
        return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype)

    if should_decompose_mm(mat1, mat2):
        counters["inductor"]["decompose_mm"] += 1
        match.replace_by_example(repl, [mat1, mat2])
        print_decompose_pattern(match, [mat1, mat2])
        realize_inputs([mat1, mat2])
    return
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from dataclasses import dataclass
3
+ from typing import Union
4
+
5
+ import torch
6
+ from torch import SymBool, SymFloat, SymInt
7
+ from torch.types import py_sym_types
8
+
9
+
10
+ @dataclass
11
+ class _SymExprHash:
12
+ """
13
+ Hash for a py_sym_types that will use the underlying sympy expression
14
+ """
15
+
16
+ sym_obj: Union[SymInt, SymFloat, SymBool]
17
+
18
+ def __hash__(self) -> int:
19
+ return hash((type(self.sym_obj), self.sym_obj.node.expr))
20
+
21
+ def __eq__(self, value) -> bool:
22
+ if not isinstance(value, _SymExprHash):
23
+ return False
24
+ return self.sym_obj.node.expr == value.sym_obj.node.expr
25
+
26
+
27
+ class _SymHashingDict:
28
+ """
29
+ Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
30
+ existing sym proxies.
31
+
32
+ SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
33
+ fallback to symnodes.
34
+ """
35
+
36
+ def __init__(self):
37
+ self.sym_hash_dict = {}
38
+
39
+ def __setitem__(self, key, value):
40
+ self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
41
+
42
+ def __getitem__(self, key):
43
+ return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
44
+
45
+ def __contains__(self, key):
46
+ return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
47
+
48
+ def get(self, key, default=None):
49
+ return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
50
+
51
+ def _wrap_to_sym_expr_hash(self, key):
52
+ return _SymExprHash(key) if isinstance(key, py_sym_types) else key
53
+
54
+
55
def dedupe_symints(graph: torch.fx.Graph):
    """
    Dedupe sym ints in the graph so equal ones resolve to a single node that
    is reachable from symint graph inputs.

    We only dedupe nodes resolvable from graph inputs to avoid adding a
    potential dependency in the forward from the backward.
    """

    sym_dict = _SymHashingDict()
    resolvable_from_input_symints = set()

    for node in graph.nodes:
        val = node.meta.get("val", None)
        # Only nodes carrying a SymInt/SymFloat/SymBool value are candidates.
        if val is None or not isinstance(val, py_sym_types):
            continue

        if node.op == "placeholder":
            resolvable_from_input_symints.add(node)
            sym_dict[val] = node
        elif existing_node := sym_dict.get(val):
            # A node with the same sym expression exists: redirect all users
            # to it and drop the duplicate.
            node.replace_all_uses_with(existing_node)
            graph.erase_node(node)
        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
            # First occurrence, and all its inputs trace back to placeholders,
            # so it is safe to reuse later.
            sym_dict[val] = node
            resolvable_from_input_symints.add(node)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch._dynamo.utils import counters
5
+ from torch._inductor import config as inductor_config
6
+ from torch.func import functional_call
7
+
8
+ from ..pattern_matcher import (
9
+ CallFunctionVarArgs,
10
+ CallModuleVarArgs,
11
+ Match,
12
+ register_graph_pattern,
13
+ )
14
+ from .pre_grad import efficient_conv_bn_eval_pass
15
+
16
+
17
def efficient_conv_bn_eval(
    bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.
    Args:
        bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
        conv (nn.modules.conv._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    """

    assert bn.running_var is not None

    # Substitute identity affine parameters when conv/bn omit them.
    base_weight = conv.weight
    base_bias = (
        conv.bias if conv.bias is not None else torch.zeros_like(bn.running_var)
    )
    gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_var)

    # Broadcast shape [C_out, 1, 1, 1] for Conv2d; transposed convs keep
    # C_out at dim 1, so the first two entries are swapped.
    broadcast_shape = [-1] + [1] * (conv.weight.ndim - 1)
    if isinstance(conv, nn.modules.conv._ConvTransposeNd):
        broadcast_shape[:2] = [broadcast_shape[1], broadcast_shape[0]]
    inv_std = torch.rsqrt(bn.running_var + bn.eps).reshape(broadcast_shape)
    # Per-output-channel scale folded into the conv weight.
    scale = gamma.view_as(inv_std) * inv_std

    # Folded weight has conv's [C_out, C_in, k, k] shape; folded bias is [C_out].
    folded_weight = base_weight * scale
    folded_bias = beta + scale.flatten() * (base_bias - bn.running_mean)

    # Run the conv with the folded parameters substituted in.
    return functional_call(
        conv, {"weight": folded_weight, "bias": folded_bias}, x
    )
75
+
76
+
77
def efficient_conv_bn_eval_decomposed(
    bn_weight,
    bn_bias,
    bn_running_mean,
    bn_running_var,
    bn_eps,
    conv: torch._ops.OpOverload,
    conv_weight,
    conv_bias,
    x,
    conv_remainging_args,
):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.
    Args:
        bn_weight / bn_bias: BatchNorm affine parameters (may be None).
        bn_running_mean / bn_running_var: BatchNorm running statistics.
        bn_eps: BatchNorm epsilon.
        conv: the convolution callable/overload to invoke.
        conv_weight / conv_bias: convolution parameters (bias may be None).
        x: input feature map.
        conv_remainging_args: extra positional args forwarded to ``conv``
            (stride, padding, ...).
    """
    assert bn_running_var is not None

    # Substitute identity affine parameters when conv/bn omit them.
    # (Original code had no-op `bn_weight = bn_weight` branches; simplified.)
    if conv_bias is not None:
        bias_on_the_fly = conv_bias
    else:
        bias_on_the_fly = torch.zeros_like(bn_running_var)

    if bn_weight is None:
        bn_weight = torch.ones_like(bn_running_var)

    if bn_bias is None:
        bn_bias = torch.zeros_like(bn_running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    target_shape = [-1] + [1] * (conv_weight.ndim - 1)
    if "conv_transpose" in conv.__str__():
        # for transposed conv, the C_out dimension should be at index 1.
        target_shape[:2] = [target_shape[1], target_shape[0]]
    weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = conv_weight * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * (
        bias_on_the_fly - bn_running_mean
    )

    return conv(*((x, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args))
138
+
139
+
140
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.nn.functional.batch_norm,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
    """Rewrite a functional ``conv -> F.batch_norm`` pair (eval mode) into a
    single call to :func:`efficient_conv_bn_eval_decomposed`, folding the BN
    statistics into the conv weights on the fly."""
    bn_node = match.nodes[0]
    graph = match.graph
    # F.batch_norm positional args:
    # (input, running_mean, running_var, weight, bias, training, momentum, eps)
    assert len(bn_node.args) == 8

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args[-3] is `training`
    if bn_node.args[-3]:
        return

    # Check if the input is Conv
    input_node = bn_node.args[0]

    if input_node.op != "call_function":  # type: ignore[union-attr]
        return

    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
    supported_convs = [
        torch._C._nn.linear,
        torch.conv1d,
        torch.conv2d,
        torch.conv3d,
        torch.conv_transpose1d,
        torch.conv_transpose2d,
        torch.conv_transpose3d,
    ]

    if not any(input_fn is cls for cls in supported_convs):
        return

    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        bn_running_mean = bn_node.args[1]
        bn_running_var = bn_node.args[2]
        bn_weight = bn_node.args[3]
        bn_bias = bn_node.args[4]
        bn_eps = bn_node.args[7]
        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
        conv_input = conv_node.args[0]  # type: ignore[union-attr]
        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
        args = (
            bn_weight,
            bn_bias,
            bn_running_mean,
            bn_running_var,
            bn_eps,
            conv_node.target,  # type: ignore[union-attr]
            conv_weight,
            conv_bias,
            conv_input,
            conv_remainging_args,
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)  # type: ignore[arg-type]

    return
229
+
230
+
231
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.ops.aten.batch_norm.default,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
    """Same rewrite as the inlined variant, but matching the decomposed
    ``aten.batch_norm.default`` op (9 positional args) fed by an aten conv."""
    bn_node = match.nodes[0]
    graph = match.graph
    # aten.batch_norm.default takes 9 positional args; training is args[5].
    assert len(bn_node.args) == 9

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args[-4] is `training`
    if bn_node.args[-4]:
        return

    # Check if the input is Conv
    input_node = bn_node.args[0]

    if input_node.op != "call_function":  # type: ignore[union-attr]
        return

    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
    supported_convs = [
        torch.ops.aten.linear.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv3d.default,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d.input,
        torch.ops.aten.conv_transpose3d.input,
    ]

    if not any(input_fn is cls for cls in supported_convs):
        return

    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        # (note: aten.batch_norm arg order differs from F.batch_norm —
        # weight/bias come before the running stats)
        bn_weight = bn_node.args[1]
        bn_bias = bn_node.args[2]
        bn_running_mean = bn_node.args[3]
        bn_running_var = bn_node.args[4]
        bn_eps = bn_node.args[7]
        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
        conv_input = conv_node.args[0]  # type: ignore[union-attr]
        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
        args = (
            bn_weight,
            bn_bias,
            bn_running_mean,
            bn_running_var,
            bn_eps,
            conv_node.target,  # type: ignore[union-attr]
            conv_weight,
            conv_bias,
            conv_input,
            conv_remainging_args,
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)  # type: ignore[arg-type]

    return
320
+
321
+
322
@register_graph_pattern(
    CallModuleVarArgs(
        [
            nn.modules.batchnorm._BatchNorm,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
        ],
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    """Module-level variant: rewrite a ``call_module`` conv followed by a
    ``call_module`` BN (eval mode) into one ``efficient_conv_bn_eval`` call
    that receives the two modules via ``get_attr`` nodes."""
    # We matched a BN node
    bn_node = match.nodes[0]
    graph = match.graph
    gm = graph.owning_module
    bn_mod = getattr(gm, bn_node.target)  # type: ignore[arg-type]

    # We can only use efficient conv-bn for eval mode with track_running_stats
    if not bn_mod.track_running_stats or bn_mod.training:
        return

    # Check if the input is Conv
    if bn_node.args:
        input_node = bn_node.args[0]
    else:
        input_node = bn_node.kwargs["input"]
    if input_node.op != "call_module":  # type: ignore[union-attr]
        return
    if not hasattr(gm, input_node.target):  # type: ignore[arg-type, union-attr]
        return
    input_mod = getattr(gm, input_node.target)  # type: ignore[arg-type, union-attr]
    supported_convs = [
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    ]
    if not any(isinstance(input_mod, cls) for cls in supported_convs):
        return
    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Find a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
        # `graph.call_function` does not allow the `name` argument.
        conv_get_node = graph.create_node(
            op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
        )
        bn_get_node = graph.create_node(
            op="get_attr", target=bn_node.target, name="get_bn"
        )
        if conv_node.args:  # type: ignore[union-attr]
            conv_input = conv_node.args[0]  # type: ignore[union-attr]
        else:
            conv_input = conv_node.kwargs["input"]  # type: ignore[union-attr]
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_input)
        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval,
            args=args,
            name="efficient_conv_bn_eval",
        )
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)  # type: ignore[arg-type]
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+
4
+ import torch
5
+ from torch._inductor.compile_fx import fake_tensor_prop
6
+
7
+ from ..._dynamo.utils import counters
8
+ from .. import config
9
+ from ..pattern_matcher import (
10
+ _return_true,
11
+ CallFunction,
12
+ fwd_only,
13
+ Ignored,
14
+ init_once_fakemode,
15
+ KeywordArg,
16
+ Match,
17
+ PatternMatcherPass,
18
+ register_graph_pattern,
19
+ register_replacement,
20
+ stable_topological_sort,
21
+ )
22
+
23
+
24
# Shorthand for the aten op namespace used throughout this module.
aten = torch.ops.aten

# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]

# Separate pass used exclusively for the binary-folding patterns, applied
# repeatedly by `freezing_passes` before the numbered passes above.
binary_folding_pass = PatternMatcherPass()
34
+
35
+
36
def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes that are applied to the graph to freeze pass.

    Alternates constant folding with binary folding (up to 4 rounds, stopping
    early once the binary-folding counter stops increasing), then applies the
    three numbered pattern passes, optional mkldnn dedup, and finally
    re-sorts, recompiles, and lints the graph.
    """

    from ..freezing import constant_fold

    lazy_init()
    # We need a few rounds of binary folding to get rid of all the
    # unnecessary nodes, but may need a good method to chose the rounds number.
    # works like: conv+binary+binary.
    binary_folding = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # If we don't have binary folding, we don't need to run the pass again.
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == binary_folding:
            break
        binary_folding = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm)

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

    for pattern in pass_patterns:
        pattern.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assume the conv's weight is channels last,
    # So make sure the layout_optimization is on when doing it.
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()
84
+
85
+
86
@init_once_fakemode
def lazy_init():
    """One-time initialization (run under fake mode via `init_once_fakemode`)
    of the pattern tables used by the freezing passes."""
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()
97
+
98
+
99
def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    """Register *pattern* with the freezing pass that runs at *pass_number*."""
    target_pass_dict = pass_patterns[pass_number]
    return register_graph_pattern(
        pattern, extra_check=extra_check, pass_dict=target_pass_dict
    )
105
+
106
+
107
def register_binary_folding_pattern(pattern, extra_check=_return_true):
    """Register *pattern* with the dedicated binary-folding pattern pass."""
    return register_graph_pattern(
        pattern, extra_check=extra_check, pass_dict=binary_folding_pass
    )
113
+
114
+
115
@functools.lru_cache(None)
def addmm_patterns_init():
    """Register replacements that fuse two or three matmuls/addmms sharing the
    same input into a single concatenated GEMM followed by ``chunk``."""
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    # Example-input factory used for tracing the pattern/replacement pairs.
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
        # Concatenation is only valid when every weight (and bias, when
        # present) is a frozen parameter (get_attr) of identical shape.
        weight_inputs = ["w1", "w2"]
        if "w3" in match.kwargs:
            weight_inputs.append("w3")

        equal_shape_inputs = [weight_inputs]

        if "b1" in match.kwargs:
            bias_inputs = ["b1", "b2"]
            if "b3" in match.kwargs:
                bias_inputs.append("b3")

            equal_shape_inputs.append(bias_inputs)

        for equal_shape_group in equal_shape_inputs:
            inps = [match.kwargs[name] for name in equal_shape_group]

            if not all(
                inp.op == "get_attr"
                and inp.meta["val"].shape == inps[0].meta["val"].shape
                for inp in inps
            ):
                return False

        return True

    def matmul_fuse_pattern(inp, w1, w2, w3):
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        # One GEMM against the column-concatenated weights, split back into 3.
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

    def matmul_fuse_pattern_two(inp, w1, w2):
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        # Fuse three addmms: concatenate weights along columns and biases
        # end-to-end, then run a single addmm and chunk the result.
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )
207
+
208
+
209
def same_dtype(match):
    """Return True when the conversion target dtype equals the input's dtype."""
    source = match.output_node().args[0]
    return source.meta["val"].dtype == match.kwargs["dtype"]
211
+
212
+
213
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    graph = match.graph
    node = match.output_node()
    # The convert is a no-op (`same_dtype` guarantees source and target dtype
    # match), so route every user to its input and drop the node.
    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
    graph.erase_node(node)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import inspect
4
+ import logging
5
+ import math
6
+
7
+ import torch
8
+ from torch.nn.attention import sdpa_kernel, SDPBackend
9
+
10
+ from ..._dynamo.utils import counters
11
+ from ..pattern_matcher import (
12
+ filter_nodes,
13
+ fwd_only,
14
+ gen_register_replacement,
15
+ joint_fwd_bwd,
16
+ )
17
+
18
+
19
log = logging.getLogger(__name__)
aten = torch.ops.aten


if torch.version.hip:
    # On ROCm, restrict SDPA dispatch to the math and flash-attention
    # backends for the fused replacement.
    def _scaled_dot_product_attention(*args, **kwargs):
        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
            return aten.scaled_dot_product_attention(*args, **kwargs)

else:
    # Elsewhere, call the aten op directly with default backend selection.
    _scaled_dot_product_attention = aten.scaled_dot_product_attention
31
+
32
+
33
+ def _sfdp_pattern_1(query, key, value, inv_scale):
34
+ return (
35
+ torch.matmul(query, key.transpose(-2, -1))
36
+ .div(inv_scale)
37
+ .softmax(dim=-1)
38
+ .matmul(value)
39
+ )
40
+
41
+
42
+ def _sfdp_replacement_1(query, key, value, inv_scale):
43
+ counters["inductor"]["fuse_attention"] += 1
44
+ return _scaled_dot_product_attention(
45
+ query.contiguous(),
46
+ key.contiguous(),
47
+ value.contiguous(),
48
+ attn_mask=None,
49
+ dropout_p=0.0,
50
+ is_causal=False,
51
+ scale=1.0 / inv_scale,
52
+ )
53
+
54
+
55
+ def _sfdp_pattern_2(query, key, value, scale_factor):
56
+ return (
57
+ torch.matmul(query, key.transpose(-2, -1))
58
+ .mul(scale_factor)
59
+ .softmax(dim=-1)
60
+ .matmul(value)
61
+ )
62
+
63
+
64
+ def _sfdp_replacement_2(query, key, value, scale_factor):
65
+ counters["inductor"]["fuse_attention"] += 1
66
+ return _scaled_dot_product_attention(
67
+ query.contiguous(),
68
+ key.contiguous(),
69
+ value.contiguous(),
70
+ attn_mask=None,
71
+ dropout_p=0.0,
72
+ is_causal=False,
73
+ scale=scale_factor,
74
+ )
75
+
76
+
77
def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    """Attention with division by inv_scale_factor and dropout on the weights."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale_factor)
        .softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)


def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
    """Fused SDPA replacement for pattern 3 (scale = 1/inv_scale_factor)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )


def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
    """Like pattern 3 but multiplies by scale_factor instead of dividing."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)


def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
    """Fused SDPA replacement for pattern 4 (scale = scale_factor)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale_factor,
    )
117
+
118
+
119
def _sfdp_pattern_5(query, key, value, attn_mask):
    """Attention with additive mask and scale 1/sqrt(d), no dropout."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    # attn_weight = torch.dropout(attn_weight, dropout_p)
    return attn_weight @ value


def _sfdp_replacement_5(query, key, value, attn_mask):
    """Fused SDPA replacement for pattern 5 (mask cast to query dtype)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
    )


def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
    """Pattern 5 plus dropout on the attention weights."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value


def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    """Fused SDPA replacement for pattern 6."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
    )
157
+
158
+
159
def _sfdp_pattern_7(query, key, value, dropout_p):
    """Permuted-input attention with fp32 softmax cast back to fp16, with dropout."""
    # in real workloads inputs to matmul are permuted
    # causing matmul to expand to a series of expand and clone calls
    # we want the same to happen during pattern tracing
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_7(query, key, value, dropout_p):
    """Fused SDPA replacement for pattern 7."""
    # sdpa prefers inputs in permuted format
    # it makes a copy to put them in this format
    # if they aren't already
    # to make replacement efficient ensure that inputs to sdpa
    # are in required order
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_8(query, key, value):
    """no dropout version of pattern 7"""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_8(query, key, value):
    """Fused SDPA replacement for pattern 8 (dropout_p = 0)."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
219
+
220
+
221
def _sfdp_pattern_9(query, key, value, dropout_p):
    """Like pattern 7, but the 1/sqrt(d) scale is applied to q before the matmul."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_9(query, key, value, dropout_p):
    """Fused SDPA replacement for pattern 9."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_10(query, key, value):
    """no dropout version of 9"""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_10(query, key, value):
    """Fused SDPA replacement for pattern 10 (dropout_p = 0)."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
275
+
276
+
277
def _sfdp_pattern_11(query, key, value, inv_scale):
    """Permuted-input attention divided by inv_scale."""
    # Mainly for huggingface models
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)


def _sfdp_replacement_11(query, key, value, inv_scale):
    """Fused SDPA replacement for pattern 11 (transpose replaces the permute)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
    """Pattern 11 plus dropout on the attention weights."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.nn.functional.dropout(
        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(v)


def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    """Fused SDPA replacement for pattern 12."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )
319
+
320
+
321
def _sfdp_pattern_13(query, key, value, dropout_p):
    """3-D (bmm-based) attention with dropout and implicit scale of 1."""
    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
    return torch.bmm(attn_weight, value)


def _sfdp_replacement_13(query, key, value, dropout_p):
    """Fused SDPA replacement for pattern 13 (adds/removes a batch dim, scale=1)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        dropout_p=dropout_p,
        scale=1.0,
    ).squeeze(0)
336
+
337
+
338
def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
    """Permuted attention with additive mask, divided by inv_scale."""
    # for BertLarge
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
        .softmax(dim=-1)
        .matmul(v)
    )


def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
    """Fused SDPA replacement for pattern 14."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
    """Attention where mask==0 positions are filled with -inf before softmax."""
    # for DistilBert
    # Permutations are needed to create clones in graph.
    # Ref: https://github.com/pytorch/pytorch/issues/119911
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v


def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
    """Fused SDPA replacement for pattern 15 (mask polarity flipped to 'keep')."""
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
399
+
400
+
401
+ def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
402
+ # for BertLarge with dropout
403
+ q = query.permute([0, 2, 1, 3])
404
+ k = key.permute([0, 2, 1, 3])
405
+ v = value.permute([0, 2, 1, 3])
406
+ return (
407
+ torch.nn.functional.dropout(
408
+ (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
409
+ dim=-1
410
+ ),
411
+ dropout_p,
412
+ )
413
+ .to(dtype=query.dtype)
414
+ .matmul(v)
415
+ )
416
+
417
+
418
def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """Replacement for _sfdp_pattern_16: fused SDPA with an additive mask.

    The mask is cast to the query dtype (additive-mask convention) and the
    pattern's 1/inv_scale division becomes SDPA's `scale` argument.
    """
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
429
+
430
+
431
def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """Search template for DistilBert-style attention with dropout.

    NOTE: traced by the pattern matcher; keep the exact op sequence
    (permute/div/masked_fill/softmax/dropout/matmul) as the model emits it.
    """
    # for DistilBert with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return (
        torch.nn.functional.dropout(
            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
        )
        @ v
    )
448
+
449
+
450
def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """Replacement for _sfdp_pattern_17: fused SDPA with a boolean mask.

    Same mask inversion trick as _sfdp_replacement_15 — SDPA masks positions
    where the boolean mask is False, so `(attn_mask == 1)` replaces the
    pattern's `(attn_mask == 0)` + masked_fill(-inf).
    """
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
469
+
470
+
471
def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
    """Search template for hf_GPT2 attention with dropout (inference).

    Returns a 3-tuple (attn_output, permuted_key, permuted_value) because the
    model reuses the permuted key/value downstream.  NOTE: traced by the
    pattern matcher; keep the exact op sequence.
    """
    # for hf_GPT2 with dropout (introduces clone node) for inference
    # it also returns permuted key & value
    query = query.permute([0, 2, 1, 3])
    key = key.permute([0, 2, 1, 3])
    value = value.permute([0, 2, 1, 3])
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    return (
        (
            torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
                value
            )
        ),
        key,
        value,
    )
498
+
499
+
500
def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
    """Replacement for _sfdp_pattern_18: fused SDPA, preserving the 3-tuple
    return shape (attn_output, permuted_key, permuted_value) the pattern has.
    """
    counters["inductor"]["fuse_attention"] += 1
    permuted_key = key.transpose(1, 2)
    permuted_value = value.transpose(1, 2)
    return (
        _scaled_dot_product_attention(
            query.transpose(1, 2),
            permuted_key,
            permuted_value,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=False,
            scale=1.0 / math.sqrt(value.size(-1)),
        ),
        permuted_key,
        permuted_value,
    )
517
+
518
+
519
def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
    """Search template for GPT2 token-classification/text-generation attention
    that applies BOTH a boolean causal mask and an additive attention mask.

    NOTE: traced by the pattern matcher; keep the exact op sequence.
    """
    # for token-classification+gpt2 / text-generation+gpt2
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    attn_weights = attn_weights + attn_mask
    attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
    return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)
536
+
537
+
538
def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
    """Replacement for _sfdp_pattern_19: fold the boolean causal mask and the
    additive mask into one additive mask (-inf where causal_mask is False),
    then call fused SDPA once.
    """
    counters["inductor"]["fuse_attention"] += 1
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = torch.where(causal_mask, attn_mask, fill_value)
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / math.sqrt(value.size(-1)),
    )
551
+
552
+
553
def _sfdp_params_check(match):
    """Validate dtype/device consistency for a candidate SDPA match.

    Rejects matches where query/key/value disagree on dtype or device and,
    when the pattern contains an attn_mask add, where the mask is not a
    tensor of a compatible dtype (query dtype, bool, or float — float shows
    up for models like albert) on the query's device.
    """
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
    key = match.kwargs["key"].meta["val"]
    value = match.kwargs["value"].meta["val"]
    same_dtype = query.dtype == key.dtype == value.dtype
    same_device = query.device == key.device == value.device
    if not same_dtype or not same_device:
        return False
    add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
    if not add_nodes:
        # Pattern has no attn_mask add; nothing further to validate.
        return True
    mask_arg = add_nodes[0].args[1]
    # The mask argument may be a plain float/int constant rather than a node.
    if not hasattr(mask_arg, "meta"):
        return False
    attn_mask = mask_arg.meta["val"]  # type: ignore[union-attr]
    if not isinstance(attn_mask, torch.Tensor):
        return False
    if attn_mask.dtype not in (query.dtype, torch.bool, torch.float):
        return False
    return attn_mask.device == query.device
583
+
584
+
585
def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
    """Build an extra-check predicate for SDPA pattern registration.

    The returned callable optionally rejects CUDA inputs (disable_cuda) and,
    when scale_factor_op is given, requires that op's scale argument to be a
    plain Python number before delegating to _sfdp_params_check.
    """

    def fn(match):
        if disable_cuda and "query" in match.kwargs:
            device = match.kwargs["query"].meta["val"].device
            if "cuda" in str(device):
                return False
        if scale_factor_op is not None:
            scale_node = filter_nodes(match.nodes, scale_factor_op)[0]
            # args[1] is always the scale factor for the current patterns.
            scale_factor = scale_node.args[1]
            # SymInt (or any non-number) scale factors are rejected here.
            if not isinstance(scale_factor, (float, int)):
                return False
        return _sfdp_params_check(match)

    return fn
603
+
604
+
605
def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial but also updates the signature on returned function
    """
    remaining_params = [
        param
        for name, param in inspect.signature(func).parameters.items()
        if name not in kwargs
    ]
    bound = functools.partial(func, **kwargs)

    def wrapper(*args, **inner_kwargs):
        return bound(*args, **inner_kwargs)

    # Downstream code inspects the signature, so expose one without the
    # partialized parameters; keep the original name for readable diagnostics.
    wrapper.__signature__ = inspect.Signature(parameters=remaining_params)  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper
626
+
627
+
628
def _get_sfdp_patterns():
    """Yield (name, register_replacement_kwargs) for every SDPA fusion pattern.

    Each candidate tuple is (search_fn, replace_fn, example_inputs, scalar
    workaround dict, extra_check).  Every candidate is emitted twice: a
    "_training" variant traced with joint_fwd_bwd and an "_inference" variant
    traced with fwd_only (with dropout_p pinned to 0.0).
    """
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g_inp = functools.partial(
        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
    )
    # attn_mask
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
        torch.empty, (1024, 128, 128), device=device, requires_grad=True
    )

    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
    g_bs1_inp = functools.partial(
        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
    )
    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)

    # softmax will generate a dtype conversion on inputs if they are in half,
    # but will not in float, so we generate a pattern for both
    for dtype in [torch.float, torch.half]:
        g = functools.partial(g_inp, dtype=dtype)
        b = functools.partial(b_inp, dtype=dtype)
        b_float = functools.partial(b_inp, dtype=torch.float)
        b_bool = functools.partial(b_inp, dtype=torch.bool)
        m = functools.partial(m_inp, dtype=dtype)
        m_float = functools.partial(m_inp, dtype=torch.float)
        m_bool = functools.partial(m_inp, dtype=torch.bool)
        c = functools.partial(c_inp, dtype=dtype)
        g_3d = functools.partial(g_3d_inp, dtype=dtype)
        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)

        candidates = [
            (
                _sfdp_pattern_1,
                _sfdp_replacement_1,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_2,
                _sfdp_replacement_2,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_3,
                _sfdp_replacement_3,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_4,
                _sfdp_replacement_4,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_5,
                _sfdp_replacement_5,
                [g(), g(), g(), b()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_6,
                _sfdp_replacement_6,
                [g(), g(), g(), b()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_7,
                _sfdp_replacement_7,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_8,
                _sfdp_replacement_8,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_9,
                _sfdp_replacement_9,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_10,
                _sfdp_replacement_10,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_11,
                _sfdp_replacement_11,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_12,
                _sfdp_replacement_12,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_13,
                _sfdp_replacement_13,
                [g_3d(), g_3d(), g_3d()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_14,
                _sfdp_replacement_14,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_15,
                _sfdp_replacement_15,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_17,
                _sfdp_replacement_17,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g(), g(), g(), m_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_19,
                _sfdp_replacement_19,
                [g(), g(), g(), b_bool(), b_float()],
                d,
                _sfdp_params_check,
            ),
        ]
        mask_fp32_patterns = ["pattern_16"]
        if dtype == torch.half:
            # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g(), g(), g(), m_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__

            # Encode dtype/mask-dtype/batch-size variants into the pattern name
            # so each serialized pattern file is unique.
            if dtype != torch.float:
                name += "_half"
                if (
                    any(p in name for p in mask_fp32_patterns)
                    and args[3].dtype == torch.float32
                ):
                    name += "_mask_fp32"
            if args[0].size(0) == 1:
                name += "_bs1"

            training_name = name + "_training"
            yield training_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": joint_fwd_bwd,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }

            if workaround:
                assert len(workaround) == 1 and "dropout_p" in workaround
                # functools.partial insufficient because we look at signature downstream
                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                replacement = partialize_and_update_signature(
                    replacement, dropout_p=0.0
                )
                workaround = {}

            inference_name = name + "_inference"
            yield inference_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": fwd_only,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
                # with dropout turned into clone, we end up with a number of
                # semantically identical graphs
                "skip_duplicates": True,
            }
904
+
905
+
906
@functools.lru_cache(None)
def _sfdp_init():
    # Register every SDPA fusion pattern exactly once per process; lru_cache
    # makes repeat calls no-ops.
    for key, register_replacement_kwargs in _get_sfdp_patterns():
        gen_register_replacement(key, **register_replacement_kwargs)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py ADDED
@@ -0,0 +1,1317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import logging
4
+ import operator
5
+ from collections import OrderedDict
6
+ from typing import (
7
+ Any,
8
+ DefaultDict,
9
+ Deque,
10
+ Dict,
11
+ Iterable,
12
+ Iterator,
13
+ List,
14
+ Optional,
15
+ Set,
16
+ Tuple,
17
+ )
18
+
19
+ import torch
20
+ from torch._dynamo.utils import counters, optimus_scuba_log
21
+ from torch._utils_internal import upload_graph
22
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
23
+
24
+ from .. import config
25
+ from ..pattern_matcher import (
26
+ CallFunctionVarArgs,
27
+ get_arg_value,
28
+ stable_topological_sort,
29
+ )
30
+
31
+
32
+ try:
33
+ # importing this will register fbgemm lowerings for inductor
34
+ import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401
35
+
36
+ has_fbgemm = True
37
+ except Exception:
38
+ has_fbgemm = False
39
+
40
aten = torch.ops.aten

log = logging.getLogger(__name__)

# Default knobs bounding the group/batch fusion search; individual fusion
# passes can override them via the "graph_search_options" kwarg.
MIN_FUSE_SET_SIZE = 5
MAX_FUSE_SET_SIZE = 300
MAX_FUSE_SEARCH_DEPTH = 5
# The maximum tensor size that can go into the fusion group
MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096
# Whether we only fuse nodes with same parent node
FUSE_NODES_WITH_SAME_PARENT = False
# Whether we enable the add broadcast in batch linear
SHAPE_BROADCAST_BATCH_LINEAR = False
# Whether we enable the fuse nodes with same users
# NOTE(review): the mixed-case name below looks unintentional (cf. the other
# ALL_CAPS constants) but is kept as-is for backward compatibility.
Fuse_NODES_WITH_SAME_USERS = False

# exclude these nodes from BFS
# excluding get item improves optimizer compilation time by 60s
SEARCH_EXCLUSIONS = {operator.getitem}


default_graph_search_options = {
    "min_fuse_set_size": MIN_FUSE_SET_SIZE,
    "max_fuse_set_size": MAX_FUSE_SET_SIZE,
    "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
    "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
    "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT,
    "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR,
    "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS,
}

graph_search_options = default_graph_search_options
72
+
73
+
74
def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
    """
    Update the example value of the node in the graph to enable followup split cat opt.
    """
    # Silently ignore missing nodes and unsupported ops; only torch.stack and
    # torch.unbind results are recorded.
    if node is None or not hasattr(node, "meta"):
        return
    if op == torch.stack:
        node.meta["example_value"] = torch.stack(metadata, dim=dim)
    elif op == torch.unbind:
        node.meta["example_value"] = torch.unbind(metadata, dim=dim)  # type: ignore[assignment]
86
+
87
+
88
def update_pointwise_example_value(pointwise_node, input, other, op):
    """
    Update the example value of the add node in the graph to enable followup split cat opt.
    """
    # Silently ignore missing nodes and unsupported ops; only torch.add and
    # torch.mul results are recorded.
    if pointwise_node is None or not hasattr(pointwise_node, "meta"):
        return
    if op == torch.add:
        pointwise_node.meta["example_value"] = torch.add(input, other)
    elif op == torch.mul:
        pointwise_node.meta["example_value"] = torch.mul(input, other)
100
+
101
+
102
class GroupBatchFusionBase:
    """Base class for group/batch fusion passes.

    A fusion pass buckets candidate graph nodes via ``match`` and rewrites
    each bucket via ``fuse``; both must be provided by subclasses.
    """

    def __init__(self, **kwargs) -> None:
        # Search knobs default to the module-level options unless the caller
        # supplies "graph_search_options" explicitly.
        self.graph_search_options = kwargs.pop(
            "graph_search_options", default_graph_search_options
        )

    def match(self, node):
        """Return a bucket key for ``node``, or None if it cannot be fused."""
        raise NotImplementedError("match called on base")

    def fuse(self, graph, subset):
        """Rewrite ``graph``, fusing the bucketed nodes in ``subset``."""
        raise NotImplementedError("fuse called on base")
113
+
114
+
115
# Name -> fusion-pass registries, populated by the @register_fusion decorator.
PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
117
+
118
+
119
def register_fusion(name: str, pre_grad=True):
    """Class decorator registering a fusion pass under ``name``.

    Registration goes to PRE_GRAD_FUSIONS when ``pre_grad`` is True and to
    POST_GRAD_FUSIONS otherwise; the decorated class is returned unchanged.
    """

    def decorator(fusion_cls: GroupBatchFusionBase):
        registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
        registry[name] = fusion_cls
        return fusion_cls

    return decorator
128
+
129
+
130
def list_group_batch_fusions(pre_grad=True) -> List[str]:
    """Return the names of all registered pre-grad (or post-grad) fusions."""
    registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
    return list(registry.keys())
135
+
136
+
137
def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
    """Emit aten.unsqueeze + aten.cat nodes equivalent to torch.stack(dim=0).

    Each input node is unsqueezed at dim 0 and the results are concatenated;
    "val" metadata is computed alongside so later passes can reason about
    shapes.  Returns the new cat node.
    """
    unsqueezed_nodes = []
    unsqueezed_vals = []
    for tensor_node in input_tensors:
        unsqueezed = graph.call_function(
            aten.unsqueeze, args=(tensor_node,), kwargs={"dim": 0}
        )
        unsqueezed.meta["val"] = aten.unsqueeze(tensor_node.meta["val"], dim=0)  # type: ignore[assignment]
        unsqueezed_nodes.append(unsqueezed)
        unsqueezed_vals.append(unsqueezed.meta["val"])
    stacked = graph.call_function(
        aten.cat, args=(unsqueezed_nodes,), kwargs={"dim": 0}
    )
    stacked.meta["val"] = aten.cat(unsqueezed_vals, dim=0)  # type: ignore[assignment]
    return stacked
152
+
153
+
154
class GroupFusion(GroupBatchFusionBase):
    """
    Fuse ops in a group way, e.g., fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
    """
158
+
159
+
160
class BatchFusion(GroupBatchFusionBase):
    """
    Fuse ops in a batch way, e.g., fuse mm/addmm of same input shapes with bmm.
    """
164
+
165
+
166
class BatchPointwiseOpsFusionFactory(BatchFusion):
    """Base for batch pointwise-op fusions; remembers the target op."""

    def __init__(self, op, **kwargs) -> None:
        super().__init__(**kwargs)
        # The torch/aten pointwise op this fusion instance batches.
        self.op = op
170
+
171
+
172
@register_fusion("batch_linear_post_grad", pre_grad=False)
class PostGradBatchLinearFusion(BatchFusion):
    """
    Fuse ops in a batch way in post grad (aten level).

    Buckets aten.mm/aten.addmm nodes by (m, k, n, has-bias[, users]) and
    rewrites each bucket as one stacked aten.bmm, re-materializing per-node
    results with aten.select (+ bias add where needed).
    """

    def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
        # Only default beta/alpha can be folded into a plain bmm + add.
        # pyre-fixme[7]: Incompatible return type
        return (
            node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0  # type: ignore[return-value]
        )

    def _is_input_2d(self, input: torch.fx.Node) -> bool:
        # Requires concrete (non-symbolic) 2D shapes.
        input_shapes = input.meta["val"].shape
        return (
            len(input_shapes) == 2
            and isinstance(input_shapes[0], int)
            and isinstance(input_shapes[1], int)
        )

    def match(
        self, node: torch.fx.Node
    ) -> Optional[Tuple[str, int, int, int, bool, str]]:
        """Return the batch bucket key for a fusable mm/addmm node, else None."""
        if CallFunctionVarArgs(aten.mm).match(node):
            input_m, weight_m = node.args
            bias_m = None

        elif CallFunctionVarArgs(aten.addmm.default).match(
            node
        ) and self._addmm_node_can_be_fused(node):
            bias_m, input_m, weight_m = node.args
        else:
            return None
        # get the user of the node
        if self.graph_search_options.get("fuse_nodes_with_same_users", False):
            users = [user.target for user in node.users.keys()]
        else:
            users = ""  # type: ignore[assignment]
        # only handle the cases where inputs are 2D tensors
        if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):  # type: ignore[arg-type]
            return None
        m, k = input_m.meta["val"].shape  # type: ignore[union-attr]
        n = weight_m.meta["val"].shape[1]  # type: ignore[union-attr]
        batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users))
        return batch_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Rewrite `subset` (same-shaped mm/addmm nodes) as one bmm.

        Inputs/weights are stacked (via decompose_stack) into batched
        operands, one aten.bmm computes all results, and each original node
        is replaced by an aten.select slice plus an optional bias add.
        "val" metadata is propagated for every new node.
        """
        batch_inputs = []
        batch_weights = []
        batch_biases = []
        batch_nodes = []
        batch_inputs_meta = []
        batch_weights_meta = []
        batch_biases_meta = []

        for node in subset:
            if CallFunctionVarArgs(aten.addmm.default).match(node):
                bias, input, weight = node.args
            elif CallFunctionVarArgs(aten.mm.default).match(node):
                input, weight = node.args
                bias = None
            batch_nodes.append(node)
            batch_inputs.append(input)  # type: ignore[possibly-undefined]
            batch_weights.append(weight)  # type: ignore[possibly-undefined]
            batch_biases.append(bias)  # type: ignore[possibly-undefined]
            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
            batch_weights_meta.append(weight.meta)  # type: ignore[possibly-undefined, union-attr]
            if bias is not None:  # type: ignore[possibly-undefined]
                batch_biases_meta.append(bias.meta)  # type: ignore[possibly-undefined, union-attr]
            else:
                batch_biases_meta.append(None)

        with graph.inserting_before(subset[-1]):
            fused_inputs = decompose_stack(graph, batch_inputs)
            fused_weights = decompose_stack(graph, batch_weights)
            fused_inputs_meta_val = torch.stack(
                [input["val"] for input in batch_inputs_meta]
            )
            fused_weights_meta_val = torch.stack(
                [weight["val"] for weight in batch_weights_meta]
            )
            fused_bmm = graph.call_function(
                aten.bmm,
                args=(fused_inputs, fused_weights),
            )
            fused_bmm.meta["val"] = aten.bmm(
                fused_inputs_meta_val, fused_weights_meta_val
            )
        for i, original_mm in enumerate(batch_nodes):
            has_bias = False
            with graph.inserting_after(fused_bmm):
                new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))
                new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i)
                if batch_biases[i]:
                    has_bias = True
                    # broadcast the bias to the same shape as the mm output
                    if self.graph_search_options.get(
                        "shape_broadcast_batch_linear", False
                    ):
                        broadcast_shape = torch.broadcast_shapes(
                            batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape
                        )
                        broadcast_bias = graph.call_function(
                            aten.broadcast_to.default,
                            args=(batch_biases[i],),
                            kwargs={"size": broadcast_shape},
                        )
                        broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape)  # type: ignore[assignment]
                        new_bias_add = graph.call_function(
                            aten.add.Tensor, args=((broadcast_bias, new_mm))
                        )
                        new_bias_add.meta["val"] = aten.add.Tensor(
                            broadcast_bias.meta["val"], new_mm.meta["val"]
                        )
                    else:
                        new_bias_add = graph.call_function(
                            aten.add, args=((batch_biases[i], new_mm))
                        )
                        new_bias_add.meta["val"] = aten.add.Tensor(
                            batch_biases_meta[i]["val"], new_mm.meta["val"]
                        )
            new_mm_cont = new_bias_add if has_bias else new_mm  # type: ignore[possibly-undefined]
            original_mm.replace_all_uses_with(new_mm_cont)
            new_mm_cont.meta.update(original_mm.meta)
            graph.erase_node(original_mm)
        counters["inductor"]["batch_linear_post_grad"] += 1
298
+
299
+
300
@register_fusion("group_linear", pre_grad=False)
class GroupLinearFusion(GroupFusion):
    """Fuse mm/addmm nodes of arbitrary (even-dim) 2D shapes into one
    fbgemm.gmm call in the post-grad graph."""

    def _addmm_node_can_be_fused(self, node: torch.fx.Node):
        # Fusable only with default beta/alpha, 2D operands, even dims, and
        # every dim within the configured size cap.
        input_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
        weight_shape = node.args[2].meta["val"].shape  # type: ignore[union-attr]
        return (
            node.kwargs.get("beta", 1.0) == 1.0
            and node.kwargs.get("alpha", 1.0) == 1.0
            and len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def _mm_node_can_be_fused(self, node: torch.fx.Node):
        # Same shape constraints as addmm, minus the beta/alpha check.
        input_shape = node.args[0].meta["val"].shape  # type: ignore[union-attr]
        weight_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
        return (
            len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
        """Bucket fusable mm/addmm nodes; key distinguishes bias-less nodes."""
        if CallFunctionVarArgs(aten.mm.default).match(
            node
        ) and self._mm_node_can_be_fused(node):
            group_key = ("group_linear", True)
        elif CallFunctionVarArgs(aten.addmm.default).match(
            node
        ) and self._addmm_node_can_be_fused(node):
            bias = node.args[0]
            group_key = ("group_linear", bias is None)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Replace `subset` with one fbgemm.gmm call plus getitem projections."""
        group_inputs = []
        group_weights = []
        group_biases = []
        group_nodes = []
        for node in subset:
            if CallFunctionVarArgs(aten.addmm.default).match(node):
                bias, input, weight = node.args
            else:
                assert CallFunctionVarArgs(aten.mm.default).match(node)
                input, weight = node.args
                bias = None

            group_nodes.append(node)
            group_inputs.append(input)
            group_weights.append(weight)
            group_biases.append(bias)

        if all(bias is None for bias in group_biases):
            # fbgemm.gmm takes a single None instead of a list of Nones.
            group_biases = None  # type: ignore[assignment]

        with graph.inserting_before(subset[0]):
            fused_mm = graph.call_function(
                torch.ops.fbgemm.gmm.default,
                args=(group_inputs, group_weights, group_biases),
                kwargs={"smart_fused": True},
            )

        for i, original_mm in enumerate(group_nodes):
            with graph.inserting_after(fused_mm):
                new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))
            original_mm.replace_all_uses_with(new_mm)
            new_mm.meta.update(original_mm.meta)
            graph.erase_node(original_mm)
        counters["inductor"]["group_linear"] += 1
379
+
380
+
381
+ class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
382
+ """
383
+ Batch pointwise math operator (e.g., add, mul) in post grad pass.
384
+ """
385
+
386
    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        # NOTE(review): the factory base already stores `op`; this second
        # assignment is redundant but harmless.
        self.op = op
389
+
390
+ def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
391
+ # note: we only consider the case where the inputs are tensors
392
+ # for mixed precision training, we need to make sure the inputs
393
+ # of the aten.cat when do the stack should be the same dtype
394
+ # otherwise, the output of the aten.cat may be not the same as
395
+ # its inputs, and cause dtype not same error in mm or addmm
396
+ input, other = node.args
397
+ return (
398
+ input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr]
399
+ if hasattr(input, "meta")
400
+ and hasattr(other, "meta")
401
+ and "val" in input.meta # type: ignore[union-attr]
402
+ and "val" in other.meta # type: ignore[union-attr]
403
+ else False
404
+ )
405
+
406
+ def match(self, node: torch.fx.Node):
407
+ if CallFunctionVarArgs(self.op).match(
408
+ node
409
+ ) and self._pointwise_node_can_be_fused(node):
410
+ alpha = node.kwargs.get("alpha", 1.0)
411
+ rounding_mode = node.kwargs.get("rounding_mode", None)
412
+ input, other = node.args
413
+ shape = list(input.meta["val"].shape) # type: ignore[union-attr]
414
+ if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
415
+ # only consider the linear case so far
416
+ # pyre-fixme[16]
417
+ if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr]
418
+ parent = (
419
+ # pyre-fixme[16]
420
+ input.args[0] # type: ignore[union-attr]
421
+ # pyre-fixme[16]
422
+ if input.target == aten.select # type: ignore[union-attr]
423
+ else other.args[0] # type: ignore[union-attr]
424
+ )
425
+ else:
426
+ parent = ""
427
+ else:
428
+ parent = ""
429
+ group_key = (
430
+ "batch_aten_" + self.op.__name__.lower().split(".")[0],
431
+ str(shape),
432
+ str(input.meta["val"].dtype), # type: ignore[union-attr]
433
+ str(other.meta["val"].dtype), # type: ignore[union-attr]
434
+ str(alpha),
435
+ str(rounding_mode),
436
+ str(parent),
437
+ )
438
+ else:
439
+ group_key = None
440
+ return group_key
441
+
442
+ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
443
+ batch_inputs, batch_others = [], []
444
+ alpha = subset[0].kwargs.get("alpha", 1.0)
445
+ batch_inputs_meta, batch_others_meta = [], []
446
+
447
+ for node in subset:
448
+ input, other = node.args
449
+ batch_inputs.append(input)
450
+ batch_others.append(other)
451
+ batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr]
452
+ batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr]
453
+
454
+ with graph.inserting_before(subset[0]):
455
+ stack_inputs = decompose_stack(graph, batch_inputs)
456
+ stack_others = decompose_stack(graph, batch_others)
457
+ stack_inputs_meta = torch.stack(
458
+ [input["val"] for input in batch_inputs_meta]
459
+ )
460
+ stack_others_meta = torch.stack(
461
+ [other["val"] for other in batch_others_meta]
462
+ )
463
+
464
+ batch_op = graph.call_function(
465
+ self.op,
466
+ args=(stack_inputs, stack_others),
467
+ kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
468
+ )
469
+ batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta)
470
+ for i, original_add in enumerate(subset):
471
+ with graph.inserting_after(batch_op):
472
+ new_add = graph.call_function(
473
+ torch.ops.aten.select, args=((batch_op, 0, i))
474
+ )
475
+ original_add.replace_all_uses_with(new_add)
476
+ new_add.meta.update(original_add.meta)
477
+ graph.erase_node(original_add)
478
+ counters["inductor"][
479
+ "batch_aten_" + self.op.__name__.lower().split(".")[0]
480
+ ] += 1
481
+
482
+
483
@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
    """
    Batch linear left-hand side fusion. This pass tries to fuse the following patterns:

    torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn)
    -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1))

    We have a separate pass to eliminate contiguous transpose in a generic way.
    """

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
        # Group linears by (has-bias, shared input node): only calls that
        # consume the *same* input tensor can share one mm/addmm.
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            bias = get_arg_value(node, 2, "bias")
            group_key = ("batch_linear_lhs", bias is None, input)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Concatenate all weights (and biases) along dim 0, run a single
        mm/addmm against the shared input, then split the fused output's
        columns back into the per-node results.
        """
        batch_nodes = []
        batch_input = None
        batch_weights, batch_weights_meta = [], []
        batch_biases, batch_biases_meta = [], []
        split_sections = []
        for node in subset:
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            batch_nodes.append(node)
            if batch_input is None:
                batch_input = input
            else:
                # match() keyed the group on the input node, so every member
                # must share it.
                assert batch_input is input
            batch_weights.append(weight)
            batch_weights_meta.append(weight.meta["example_value"])
            if bias:
                batch_biases.append(bias)
                batch_biases_meta.append(bias.meta["example_value"])
            # Each weight contributes out_features rows to the concatenation;
            # remember them so the fused output can be split back per node.
            split_sections.append(weight.meta["example_value"].shape[0])

        with graph.inserting_before(subset[0]):
            cat_weights = graph.call_function(
                torch.cat, args=(batch_weights,), kwargs={"dim": 0}
            )
            cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0)
            transposed_weights = graph.call_function(
                torch.transpose, args=(cat_weights, 0, 1)
            )
            transposed_weights.meta["example_value"] = torch.transpose(
                cat_weights.meta["example_value"], 0, 1
            )
            if len(batch_biases) > 0:
                cat_biases = graph.call_function(
                    torch.cat, args=(batch_biases,), kwargs={"dim": 0}
                )
                cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0)
                fused_lhs = graph.call_function(
                    torch.addmm,
                    args=(cat_biases, batch_input, transposed_weights),
                )
                fused_lhs.meta["example_value"] = torch.addmm(
                    cat_biases.meta["example_value"],
                    batch_input.meta["example_value"],  # type: ignore[union-attr]
                    transposed_weights.meta["example_value"],
                )
            else:
                fused_lhs = graph.call_function(
                    torch.mm,
                    args=(batch_input, transposed_weights),
                )
                fused_lhs.meta["example_value"] = torch.mm(
                    batch_input.meta["example_value"],  # type: ignore[union-attr]
                    transposed_weights.meta["example_value"],
                )
            fused_lhs_list = graph.call_function(
                torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
            )

        for i, node in enumerate(batch_nodes):
            with graph.inserting_after(fused_lhs_list):
                new_node = graph.call_function(
                    operator.getitem, args=(fused_lhs_list, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
        counters["inductor"]["batch_linear_lhs"] += 1
574
+
575
+
576
def is_node_meta_valid(node: Optional[torch.fx.Node]):
    """Return True when *node* is absent or carries usable metadata.

    Usable metadata means either an "example_value" entry (pre-grad graphs)
    or a "val" entry (post-grad graphs) in ``node.meta``.
    """
    if node is None:
        return True
    meta = node.meta
    return "example_value" in meta or "val" in meta
578
+
579
+
580
+ # Poor person's check for if a node in the graph mutates its input.
581
+ # (the graph is torch IR, so we will see torch fns and python operators)
582
+ def _is_mutable_node(tgt):
583
+ if str(tgt).endswith("_"):
584
+ # e.g. torch.mul_, torch.Tensor.mul_
585
+ return True
586
+ if (
587
+ hasattr(tgt, "__module__")
588
+ and tgt.__module__ == "_operator"
589
+ and tgt.__name__.startswith("i")
590
+ ):
591
+ # e.g. operator.iand, operator.imul
592
+ return True
593
+ return False
594
+
595
+
596
def is_linear_node_can_be_fused(node: torch.fx.Node):
    """Return True when a torch.nn.functional.linear call is safe to batch.

    Requires example metadata on the call, its input and its weight, plain 2D
    input/weight matrices, and no in-place-mutating user of the output (the
    mm -> bmm transform adds an unbind() op, which is not safe for autograd
    when the output of the mm is mutated).
    """
    input = get_arg_value(node, 0, "input")
    weight = get_arg_value(node, 1, "weight")
    if not is_node_meta_valid(node):
        return False
    if not (is_node_meta_valid(input) and is_node_meta_valid(weight)):
        return False
    if len(input.meta["example_value"].shape) != 2:
        return False
    if len(weight.meta["example_value"].shape) != 2:
        return False
    # don't pattern match if any users of the mm mutate the input
    return not any(_is_mutable_node(user.target) for user in node.users)
610
+
611
+
612
@register_fusion("batch_linear")
class PreGradBatchLinearFusion(BatchFusion):
    """
    Batch linear fusion in pre grad pass.
    Fuse linear with same size with torch.baddmm
    """

    def _getitem_args(self, getitem_node: torch.fx.Node):
        # Returns the container an operator.__getitem__ call reads from, or
        # None for any other node; used in the group key so linears fed from
        # the same split/unbind output group together.
        if getitem_node.target != operator.__getitem__ or (
            getitem_node.op != "call_function"
        ):
            return None
        return getitem_node.args[0]

    def match(self, node: torch.fx.Node):
        # Key on (getitem source, input shape, weight shape, has-bias, and
        # optionally the user targets) so only identically-shaped linears are
        # batched together.
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
                users = [user.target for user in node.users.keys()]
            else:
                users = ""  # type: ignore[assignment]
            group_key = (
                "batch_linear",
                self._getitem_args(input),
                str(input.meta["example_value"].shape),
                str(weight.meta["example_value"].shape),
                bias is None,
                str(users),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack inputs/weights (and biases) and replace the individual
        linears with one bmm (bias-free) or baddbmm, followed by an unbind
        whose getitems take over each original node's uses.
        """
        batch_nodes = []
        batch_inputs = []
        batch_weights = []
        batch_biases = []
        batch_inputs_metadata = []
        batch_weights_metadata = []
        batch_biases_metadata = []
        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])
            weight = get_arg_value(node, 1, "weight")
            batch_weights.append(weight)
            batch_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 2, "bias")
            batch_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                batch_biases_metadata.append(bias.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            stack_weights = graph.call_function(
                torch.stack, args=(batch_weights,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_weights, batch_weights_metadata)
            # linear computes x @ W.T, so transpose the stacked weights'
            # last two dims before the batched matmul.
            transpose_weight = graph.call_function(
                torch.transpose, args=(stack_weights, 1, 2)
            )
            transpose_weight.meta["example_value"] = torch.transpose(
                stack_weights.meta["example_value"], 1, 2
            )
            if all(bias is None for bias in batch_biases):
                bmm = graph.call_function(
                    torch.bmm,
                    args=(stack_inputs, transpose_weight),
                )
                bmm.meta["example_value"] = torch.bmm(
                    stack_inputs.meta["example_value"],
                    transpose_weight.meta["example_value"],
                )
                bmm_meta = bmm.meta["example_value"]
            else:
                stack_biases = graph.call_function(
                    torch.stack, args=(batch_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_biases, batch_biases_metadata)
                # Unsqueeze so the stacked bias broadcasts over each batch's
                # rows inside baddbmm.
                unsqueeze_biases = graph.call_function(
                    torch.unsqueeze, args=(stack_biases, 1)
                )
                unsqueeze_biases.meta["example_value"] = torch.unsqueeze(
                    stack_biases.meta["example_value"], 1
                )
                bmm = graph.call_function(
                    torch.baddbmm,
                    args=(unsqueeze_biases, stack_inputs, transpose_weight),
                )
                try:
                    # it will have runtime error to broadcast when it has dynamic shape included
                    # in the meta data, so we need to skip the update meta data
                    bmm.meta["example_value"] = torch.baddbmm(
                        unsqueeze_biases.meta["example_value"],
                        stack_inputs.meta["example_value"],
                        transpose_weight.meta["example_value"],
                    )
                    bmm_meta = bmm.meta["example_value"]
                except Exception as e:
                    log.debug(
                        f" exception when update bmm meta data with stack error tracekey {e}"  # noqa: G004
                    )
                    bmm_meta = None

            bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
            if bmm_meta is not None:
                bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
            for i, linear in enumerate(batch_nodes):
                with graph.inserting_after(bmm):
                    getitem = graph.call_function(operator.getitem, args=(bmm, i))
                linear.replace_all_uses_with(getitem)
                getitem.meta.update(linear.meta)
                graph.erase_node(linear)
        counters["inductor"]["batch_linear"] += 1
735
+
736
+
737
@register_fusion("batch_layernorm")
class BatchLayernormFusion(BatchFusion):
    """
    Batch layer norm fusion in pre grad pass
    """

    def match(self, node: torch.fx.Node):
        # Key on input/weight/bias shapes, normalized_shape, eps and
        # (optionally) the user targets, so only identically-configured
        # layer norms are grouped.
        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 2, "weight")
            bias = get_arg_value(node, 3, "bias")
            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
                users = [user.target for user in node.users.keys()]
            else:
                users = ""  # type: ignore[assignment]
            group_key = (
                (
                    "batch_layernorm",
                    str(input.meta["example_value"].shape),
                    str(weight.meta["example_value"].shape)
                    if weight is not None
                    else "",
                    str(bias.meta["example_value"].shape) if bias is not None else "",
                    str(get_arg_value(node, 1, "normalized_shape")),
                    str(get_arg_value(node, 4, "eps")),
                    str(users),
                )
                if "example_value" in input.meta
                and is_node_meta_valid(weight)
                and is_node_meta_valid(bias)
                else None
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the inputs of all layer norms in ``subset`` along a new batch
        dim, run one unscaled layer_norm, apply the stacked weight/bias as
        explicit mul/add, then unbind the result back onto the originals.
        """
        group_inputs = []
        group_shapes = []
        group_weights = []
        group_biases = []
        group_epss = []
        group_nodes = []
        group_inputs_metadata = []
        group_biases_metadata = []
        group_weights_metadata = []
        for node in subset:
            group_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            group_inputs.append(input)
            group_inputs_metadata.append(input.meta["example_value"])
            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
            weight = get_arg_value(node, 2, "weight")
            group_weights.append(weight)
            if weight is not None and hasattr(weight, "meta"):
                group_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 3, "bias")
            group_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                group_biases_metadata.append(bias.meta["example_value"])
            eps = get_arg_value(node, 4, "eps")
            if eps is None:
                # layer_norm's documented default epsilon.
                eps = 1e-5
            group_epss.append(eps)
        # Stack just before the normalized dims so the normalized_shape
        # argument stays valid for the batched call.
        stack_dim = -1 - len(group_shapes[-1])

        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        if all(weight is None for weight in group_weights):
            group_weights = None  # type: ignore[assignment]
        assert all(
            eps == group_epss[0] for eps in group_epss
        ), "all epsilon values must be equal"

        with graph.inserting_before(subset[0]):
            stack_input = graph.call_function(
                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
            )
            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
            if group_weights is not None:
                stack_weight = graph.call_function(
                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_weight, group_weights_metadata)
            else:
                stack_weight = None
            if group_biases is not None:
                stack_bias = graph.call_function(
                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_bias, group_biases_metadata)
            else:
                stack_bias = None

            # Normalize without weight/bias; the affine part is applied below
            # with the stacked parameters.
            batch_layer_norm = graph.call_function(
                torch.nn.functional.layer_norm,
                args=(stack_input, group_shapes[-1]),
                kwargs={"eps": group_epss[-1]},
            )
            # layer_norm preserves shape, so reuse the stacked input's meta.
            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]

            if group_weights is not None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )
            elif group_weights is not None and group_biases is None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
            elif group_weights is None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )

            batch_layer_norm_unbind = graph.call_function(
                torch.unbind,
                args=(batch_layer_norm,),
                kwargs={"dim": stack_dim},
            )
            update_stack_example_value(
                batch_layer_norm_unbind,
                batch_layer_norm.meta["example_value"],
                op=torch.unbind,
                dim=stack_dim,
            )

        for i, node in enumerate(group_nodes):
            with graph.inserting_after(batch_layer_norm_unbind):
                new_node = graph.call_function(
                    operator.getitem, args=(batch_layer_norm_unbind, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
        counters["inductor"]["batch_layernorm"] += 1
903
+
904
+
905
class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass.
    We fuse it in random place, and the introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        # The torch-level callable (e.g. torch.tanh) this instance batches.
        self.op = op

    def match(self, node: torch.fx.Node):
        """Return a group key (op name, input shape, inplace flag, optional
        parent) for fusable nodes, else None."""
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
                # pyre-fixme[16]
                parent = node.args[0]
                parent = parent.target if parent is not None else ""  # type: ignore[union-attr]
            else:
                parent = ""
            # for relu op, we also use the inplace to construct the key
            group_key = (
                "batch_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["example_value"].shape),
                str(node.kwargs.get("inplace", False)),
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the inputs of every node in ``subset``, apply ``self.op``
        once on the stacked tensor, unbind, and rewire each original node to
        its getitem."""
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            if self.op == torch.nn.functional.relu:
                # relu is the only batched op here that takes an inplace
                # kwarg; preserve the subset's (key-uniform) setting.
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"],
                    inplace=subset[0].kwargs.get("inplace", False),
                )
            else:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"]
                )
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
            unbind_op.meta["example_value"] = torch.unbind(
                batch_op.meta["example_value"], dim=0
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(unbind_op):
                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
                node.replace_all_uses_with(getitem)
                getitem.meta.update(node.meta)
                graph.erase_node(node)
        counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1
982
+
983
+
984
class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass.
    The introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        # The aten overload (e.g. aten.tanh.default) this instance batches.
        self.op = op

    def match(self, node: torch.fx.Node):
        """Return a group key (op name, input shape, inplace flag, optional
        parent) for fusable nodes, else None."""
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            # for relu op, we also use the inplace to construct the key
            # we batch the ops with same parent to enable followup split cat
            parent = node.args[0]
            parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else ""  # type: ignore[union-attr]
            group_key = (
                "batch_aten_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["val"].shape),
                str(node.kwargs.get("inplace", False)),
                # pyre-fixme[16]
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the inputs of every node in ``subset``, apply ``self.op``
        once, and replace each original node with an aten.select of row i."""
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["val"])

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs,),
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(batch_op):
                    getitem = graph.call_function(aten.select, args=(batch_op, 0, i))
                    node.replace_all_uses_with(getitem)
                    getitem.meta.update(node.meta)
                    graph.erase_node(node)
        counters["inductor"][
            "batch_aten_" + self.op.__name__.lower().split(".")[0]
        ] += 1
1039
+
1040
+
1041
@register_fusion("batch_tanh")
class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batches torch.tanh calls in the pre-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(torch.tanh, **kwargs)
1045
+
1046
+
1047
@register_fusion("batch_sigmoid")
class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batches torch.sigmoid calls in the pre-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(torch.sigmoid, **kwargs)
1051
+
1052
+
1053
@register_fusion("batch_relu")
class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batches torch.nn.functional.relu calls in the pre-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(torch.nn.functional.relu, **kwargs)
1057
+
1058
+
1059
@register_fusion("batch_aten_tanh", pre_grad=False)
class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batches aten.tanh.default calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.tanh.default, **kwargs)
1063
+
1064
+
1065
@register_fusion("batch_aten_sigmoid", pre_grad=False)
class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batches aten.sigmoid.default calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sigmoid.default, **kwargs)
1069
+
1070
+
1071
@register_fusion("batch_aten_relu", pre_grad=False)
class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batches aten.relu.default calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.relu.default, **kwargs)
1075
+
1076
+
1077
@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    """Batches aten.add.Tensor calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.add.Tensor, **kwargs)
1081
+
1082
+
1083
@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    """Batches aten.sub.Tensor calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sub.Tensor, **kwargs)
1087
+
1088
+
1089
@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    """Batches aten.div.Tensor calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.div.Tensor, **kwargs)
1093
+
1094
+
1095
@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    """Batches aten.mul.Tensor calls in the post-grad pass."""

    def __init__(self, **kwargs) -> None:
        super().__init__(aten.mul.Tensor, **kwargs)
1099
+
1100
+
1101
+ class _OrderedSet:
1102
+ def __init__(self, param=None) -> None:
1103
+ if param:
1104
+ self.rep = OrderedDict(dict.fromkeys(param))
1105
+ else:
1106
+ self.rep = OrderedDict()
1107
+
1108
+ def __contains__(self, o) -> bool:
1109
+ return o in self.rep
1110
+
1111
+ def __len__(self) -> int:
1112
+ return self.rep.__len__()
1113
+
1114
+ def append(self, o):
1115
+ self.rep[o] = None
1116
+
1117
+ def __iter__(self):
1118
+ return self.rep.keys().__iter__()
1119
+
1120
+
1121
def find_independent_subset_greedy(
    node_list: Iterable[torch.fx.Node],
    graph_search_options: Dict[str, Any],
) -> Iterator[Iterable[torch.fx.Node]]:
    """
    Yields a list of subsets of `node_list` where no element in the subset
    depends on any other element in the subset. This results in a set of
    independent nodes which can be fused together.

    The order of `node_list` is preserved within each subset so we can benefit
    from split-cat elimination in later passes.

    During iteration it is only safe to mutate the graph by changing the nodes
    that have been returned.

    graph_search_options:
        - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
          this size will be ignored.
        - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
          be broken to be at most this size.
    """

    # Compute all the children of `node` which are members of
    # `interesting_nodes`.
    def find_dependent_nodes(node, interesting_nodes):
        # Iterative DFS over the node's transitive inputs.
        visited_node_set: Set[torch.fx.Node] = {node}
        dep_set: Set[torch.fx.Node] = set()

        work = [node]
        while work:
            node = work.pop()
            for input_node in node.all_input_nodes:
                if input_node in interesting_nodes:
                    dep_set.add(input_node)

                if input_node not in visited_node_set:
                    visited_node_set.add(input_node)
                    work.append(input_node)

        return dep_set

    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
    max_fuse_set_size = graph_search_options["max_fuse_set_size"]

    # node_list needs to be a set because we only track the nodes that are left
    # in it (and we want to do the `in` on a set, not a list). But we want to
    # keep the correct order.
    node_list = _OrderedSet(node_list)

    # Memoized per-node dependency sets; entries are invalidated whenever a
    # yielded subset could change them (see below).
    cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
    while node_list:
        subset: List[torch.fx.Node] = []
        subset_deps: Set[torch.fx.Node] = set()

        # Nodes that can't join the current subset get retried next round.
        next_round_node_list = _OrderedSet()
        for node in node_list:
            if len(subset) >= max_fuse_set_size or node in subset_deps:
                next_round_node_list.append(node)
                continue

            dep_set = cache.pop(node, None)
            if dep_set is None:
                dep_set = find_dependent_nodes(node, node_list)

            if not dep_set.intersection(subset):
                subset.append(node)
                subset_deps.update(dep_set)
            else:
                next_round_node_list.append(node)
                cache[node] = dep_set

        if len(subset) >= min_fuse_set_size:
            # Careful here - the caller uses the subsets to fuse nodes together
            # so we need to clear any cache entry that contains one of the
            # returned nodes because the dependency list could be different
            # (larger) after the merge.
            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
            yield subset

        node_list = next_round_node_list
1201
+
1202
+
1203
def get_fusion_candidates(
    rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
    """
    Search fusion candidates for a specific rule using BFS starting from the root node.
    We only search the subgraph within graph_search_options["max_fuse_search_depth"].
    """
    candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
        list
    )

    # Some targets are explicitly excluded from the search entirely.
    if root_node.target in SEARCH_EXCLUSIONS:
        return candidate_dict

    # BFS queue of (depth, node); the visited set keeps nodes reachable along
    # multiple paths from being enqueued twice.
    queue: Deque[Tuple[int, torch.fx.Node]] = collections.deque(
        (1, parent) for parent in root_node.all_input_nodes
    )
    visited_set: Set[torch.fx.Node] = set(root_node.all_input_nodes)

    max_depth = rule.graph_search_options["max_fuse_search_depth"]
    while queue:
        depth, node = queue.popleft()

        # Nodes already consumed by an earlier fusion are dead ends.
        if node in fused_set:
            continue

        key = rule.match(node)
        if key is not None:
            # A match terminates this branch; bucket the node under its key.
            bucket = candidate_dict[key]
            if node not in bucket:
                bucket.append(node)
        elif depth < max_depth:
            for parent in node.all_input_nodes:
                if parent not in visited_set:
                    visited_set.add(parent)
                    queue.append((depth + 1, parent))

    return candidate_dict
1244
+
1245
+
1246
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
    """Run one fusion rule over the whole graph.

    Walks nodes in reverse order after a stable topological sort, gathers
    key-grouped candidates via BFS, fuses every sufficiently large independent
    subset, and uploads the transformed graph when anything changed.
    """
    stable_topological_sort(graph)  # type: ignore[arg-type]
    # Nodes consumed by a fusion; later searches skip them.
    fused_set: Set[torch.fx.Node] = set()
    log_to_scuba = False

    for node in reversed(graph.nodes):
        candidates = get_fusion_candidates(rule, node, fused_set)

        for key, candidate_nodes in candidates.items():
            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
                continue

            for subset in find_independent_subset_greedy(
                candidate_nodes, rule.graph_search_options
            ):
                rule.fuse(graph, subset)
                fused_set.update(subset)
                log.debug(
                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}"  # noqa: G004
                )
                log_to_scuba = True
    if log_to_scuba:
        optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph)
1269
+
1270
+
1271
def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
    """Instantiate fusion rule objects from a {fusion_name: options} config.

    Names not registered as group/batch fusions (e.g. split-cat
    pattern-matcher passes) are skipped; per-name options override the
    module-level graph_search_options defaults.
    """
    fusions: List[GroupBatchFusionBase] = []
    for name, options in config_options.items():
        # we skip all patterns from pattern_matcher passes (e.g., split_cat)
        if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS:
            continue
        # NOTE(review): a name registered only in the *other* phase's registry
        # passes the check above but would raise KeyError on the lookup below
        # — presumably configs never mix phases; confirm.
        fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name]
        _options = graph_search_options.copy()
        _options.update(options)
        fusions.append(fusion_cls(graph_search_options=_options))  # type: ignore[operator]
    return fusions
1282
+
1283
+
1284
def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    """Entry point: build the configured fusion rules for this phase and apply
    each one to ``graph`` under a GraphTransformObserver."""
    fusions: List[GroupBatchFusionBase] = []
    # we keep all current pre grad fusions to keep
    # current implementation, will remove this later
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
    else:
        # Split post-grad fusions into fbgemm-backed and plain ones so the
        # fbgemm rules are only added when the library imported successfully.
        fbgemm_fusion_keys = [
            x
            for x in config.post_grad_fusion_options
            if config.post_grad_fusion_options[x].get("require_fbgemm", False)
        ]
        fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in fbgemm_fusion_keys
        }
        non_fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in config.post_grad_fusion_options.keys()
            if fusion not in fbgemm_fusion_keys
        }
        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
        if has_fbgemm:
            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for i, rule in enumerate(fusions):
        with GraphTransformObserver(
            graph.owning_module,
            f"group_batch_fusion_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import itertools
3
+ import logging
4
+ import typing
5
+ from collections import Counter
6
+ from typing import Any, Dict, List, Set, Union
7
+
8
+ import torch
9
+ import torch._guards
10
+ import torch.utils._pytree as pytree
11
+ from torch._inductor.constant_folding import ConstantFolder
12
+ from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
13
+ from torch.fx.experimental.symbolic_shapes import statically_known_true
14
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
15
+ from torch.multiprocessing.reductions import StorageWeakRef
16
+
17
+ from ...utils._ordered_set import OrderedSet
18
+ from .. import config
19
+ from ..pattern_matcher import (
20
+ CallFunction,
21
+ init_once_fakemode,
22
+ KeywordArg,
23
+ Match,
24
+ MULTIPLE,
25
+ PatternMatcherPass,
26
+ register_graph_pattern,
27
+ stable_topological_sort,
28
+ )
29
+ from .replace_random import replace_random_passes
30
+
31
+
32
log = logging.getLogger(__name__)
# First pattern-matcher pass; patterns registered on this object run before
# the ones registered on pass_patterns[1].
patterns = PatternMatcherPass()
aten = torch.ops.aten
prims = torch.ops.prims

# Ordered list of pattern passes applied by joint_graph_passes.
pass_patterns = [
    patterns,
    PatternMatcherPass(),
]
41
+
42
+
43
@init_once_fakemode
def lazy_init():
    """Register pattern-matcher patterns once (guarded by init_once_fakemode)."""
    # Imports are deferred so pattern registration only pays its cost when
    # the joint-graph passes actually run.
    from .fuse_attention import _sfdp_init
    from .misc_patterns import _misc_patterns_init
    from .pad_mm import _pad_mm_init

    _pad_mm_init()
    _sfdp_init()
    _misc_patterns_init()
52
+
53
+
54
def remove_no_ops(
    gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
):
    """Removes no-ops: (+ 0, - 0, * 1, / 1).

    `zeros` / `ones` are nodes known to produce all-zero / all-one tensors
    (populated by constant_fold_uniform_value). Also replaces meta-device
    tensors returned from the graph with empty_strided constructors.
    """
    # Disable fake/functional dispatch modes while rewriting the graph.
    with torch.utils._python_dispatch._disable_current_modes():
        graph = gm.graph

        def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
            # Metadata-only equality between two (fake) tensors.
            if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
                return False
            for field in fields:
                if getattr(t1, field) != getattr(t2, field):
                    return False
            return True

        def replace_no_op(node, replace_input_index):
            # Replace `node` with its input at replace_input_index, inserting
            # a dtype conversion when only the dtype differs.
            replacement = node.args[replace_input_index]

            # https://github.com/pytorch/pytorch/issues/86128 causes
            # non-Tensor inputs even for ops with only Tensor inputs.
            # TODO - decompose/type promote to avoid this
            if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
                return

            if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
                if fake_tensors_eq(
                    node.meta["val"],
                    replacement.meta["val"],
                    ("shape", "device"),
                ):
                    with graph.inserting_after(node):
                        replacement = graph.call_function(
                            torch.ops.prims.convert_element_type.default,
                            args=(replacement, node.meta["val"].dtype),
                        )
                else:
                    # Shape or device mismatch: not actually a no-op, bail.
                    return

            node.replace_all_uses_with(replacement)
            replacement.meta.update(node.meta)
            graph.erase_node(node)

        for node in graph.find_nodes(op="call_function", target=aten.add.Tensor):
            # TODO handle Tensor-Scalar adds, it's a different schema
            if len(node.args) == 2:
                if (
                    not any(e in zeros for e in node.args)
                    or node.kwargs.get("alpha", 1) != 1
                ):
                    continue

                replace_index = 1 if node.args[0] in zeros else 0
                replace_no_op(node, replace_index)

        for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor):
            # Subtraction is only a no-op when the *second* operand is zero.
            if len(node.args) == 2:
                if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
                    continue

                replace_no_op(node, 0)

        for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor):
            if len(node.args) == 2:
                if not any(e in ones for e in node.args):
                    continue

                replace_input_index = 1 if node.args[0] in ones else 0
                replace_no_op(node, replace_input_index)

        for node in graph.find_nodes(op="call_function", target=aten.div.Tensor):
            # Division is only a no-op when the divisor is one.
            if len(node.args) == 2 and node.args[1] in ones:
                replace_no_op(node, 0)

        # meta tensors returned from the graph have no data and can be replaced with empty_strided
        for output_node in graph.find_nodes(op="output"):
            had_meta_return = False

            def visit(n):
                nonlocal had_meta_return
                val = n.meta.get("val")
                if isinstance(val, torch.Tensor) and val.device.type == "meta":
                    with graph.inserting_before(output_node):
                        n.replace_all_uses_with(
                            graph.call_function(
                                torch.ops.aten.empty_strided.default,
                                args=(val.size(), val.stride()),
                                kwargs={"dtype": val.dtype, "device": val.device},
                            )
                        )
                    had_meta_return = True

            torch.fx.map_arg(output_node.args, visit)
            if had_meta_return:
                # Original producers of the meta outputs may now be dead.
                graph.eliminate_dead_code()
148
+
149
+
150
def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant dtype views (aten.view.dtype) by reusing existing ones.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        # A dictionary mapping a tensor to all aliased views.
        views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
        graph = gm.graph

        for node in graph.find_nodes(
            op="call_function", target=torch.ops.aten.view.dtype
        ):
            src = node.args[0]
            to_type = node.args[1]
            existing_views = views.get(src)
            is_needed = True

            if existing_views:
                # Replace the view with an existing view if available.
                alias = existing_views.get(to_type)
                if alias:
                    is_needed = False
                    node.replace_all_uses_with(alias)
                    alias.meta.update(node.meta)
                    graph.erase_node(node)
            else:
                # First view of `src`: seed the alias table with src itself
                # keyed by its own dtype.
                from_type = src.meta["val"].dtype
                existing_views = {from_type: src}
                views[src] = existing_views

            if is_needed:
                # Save the new alias but do not replace existing one.
                existing_views.setdefault(to_type, node)
                views[node] = existing_views

        # Clean up unused views. Iterate to a fixed point because erasing one
        # view can leave another with no users.
        while True:
            unused_views = [alias for alias in views if not alias.users]
            if len(unused_views) == 0:
                break
            for unused in unused_views:
                views.pop(unused)
                graph.erase_node(unused)
193
+
194
+
195
class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
    with a tensor constructor call: aten.full([shape], value, ...)
    """

    def __init__(self, gm, skip_constructors=False) -> None:
        super().__init__(gm, skip_constructors)
        self.node_storages_ptrs: Dict[torch.fx.Node, int] = {}
        # Weak refs to the folded tensors' storages; used by the caller to
        # detect aliasing between folded constants.
        self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {}
        # we may constant fold a tensor which in the graph has a sym size
        # see: [constant folding refining of symints]
        self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

        # initialize symint -> node mapping so that we can
        # use symint nodes in full constructors
        self.symint_nodes = _SymHashingDict()
        for n in self.module.graph.nodes:
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                self.symint_nodes[n.meta["val"]] = n

        # reference from torch/_funtorch/partitioners.py:get_default_op_list
        self.view_op_packets = [
            aten.squeeze,
            aten.unsqueeze,
            aten.alias,
            aten.view,
            aten.slice,
            aten.t,
            prims.broadcast_in_dim,
            aten.expand,
            aten.as_strided,
            aten.permute,
        ]

        self.indexing_op_packets = {
            aten.slice,
        }

    def _support_dynamic_shape(self):
        # This folder can fold through symbolic shapes (see _deduce_value).
        return True

    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
        # Accept every tensor here; filtering (contiguity, dtype, aliasing)
        # happens later in constant_fold_uniform_value.
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        # Record the uniform scalar value, the (possibly symbolic) shape from
        # the node's meta, and a weak storage ref for aliasing checks.
        self.node_replacements[node] = tensor.flatten()[0].item()
        self.node_replacements_shapes[node] = node.meta["val"].shape
        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())

    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        # SymInt placeholders keep their symbolic value so they can flow into
        # constructors/views; all other placeholders are unknown.
        for n in self.module.graph.find_nodes(op="placeholder"):
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                env[n] = n.meta["val"]
            else:
                env[n] = self.unknown_value

    def _deduce_value(self, node: torch.fx.Node):
        # deduce value for full-like nodes
        # 1. for constructors, substitute value is a tensor of size [1]
        # 2. for view ops/indexing, substitute value is the same as the input
        # 3. for pointwise ops, run node to get the substitute value
        # 4. deal with some special ops
        # otherwise, stop deduce value and return unknown value

        # TODO: cat, more indexing
        # TODO - do on cpu to avoid syncs

        # single-elem attrs
        if node.op == "get_attr" or (
            node.op == "call_function"
            and node.target == torch.ops.aten.lift_fresh_copy.default
        ):
            out = super(ConstantFolder, self).run_node(node)
            if isinstance(out, torch.Tensor) and out.numel() == 1:
                return out

        # handle device_put op
        if node.target == prims.device_put.default:
            return super(ConstantFolder, self).run_node(node)

        # constructors ops
        if (
            node.op == "call_function"
            and node.target == aten.full.default
            and len(node.args) == 2
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            # Materialize only a single element; the real shape is tracked
            # separately in node_replacements_shapes.
            new_args = [[1], args[1]]
            return aten.full.default(*new_args, **node.kwargs)

        # handle before view ops because this changes value
        if node.target == aten.view.dtype:
            return super(ConstantFolder, self).run_node(node)

        # view ops, return input tensor, the first argument
        if hasattr(node.target, "overloadpacket") and (
            node.target.overloadpacket in self.view_op_packets
            or node.target.overloadpacket in self.indexing_op_packets
        ):
            assert isinstance(node.args[0], torch.fx.Node)
            return self.env[node.args[0]]

        # we don't want to return unknown value for symints so that we can
        # still constant fold through their use in constructors or views
        # if we see them in a pointwise node (e.g., tensor * symint)
        # we will bail
        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
            return node.meta["val"]

        # pointwise ops
        if isinstance(node.target, torch._ops.OpOverload) and (
            torch.Tag.pointwise in node.target.tags
            or node.target is torch.ops.aten.scalar_tensor.default
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
                return self.unknown_value

            # we run the ops with dim 1, so remove memory_format to avoid error
            kwargs = dict(kwargs)
            kwargs.pop("memory_format", None)

            return node.target(*args, **kwargs)

        return self.unknown_value
323
+
324
+
325
def constant_fold_uniform_value(gm: torch.fx.GraphModule):
    """Runs constant folding and replaces constants which can be constructed
    with a single `full` call. Calls into remove_no_ops."""
    with torch.utils._python_dispatch._disable_current_modes():
        aten = torch.ops.aten

        # Constant folding can leak memory, especially with repeated compilation, so we are only going to
        # remove constants which can be replaced with a constructor.
        cf = UniformValueConstantFolder(gm)
        cf.run()

        node_replacements = cf.node_replacements

        # note: [constant folding refining of symints]
        # constant folding will partially evaluate a graph such that values which have dependencies which
        # are entirely known at compile time may also become compile time constants. in some cases,
        # this will include symints which we had not yet previously deduced are guaranteed a
        # constant value and is then deduced in constant folding. an example is:
        # unbacked_symint_eq_11 = torch.full((), 11).item()
        # torch.full((unbacked_symint_eq_11,), 0)
        node_replacements_shapes = cf.node_replacements_shapes

        graph = gm.graph

        # Nodes replaced by full(0)/full(1); fed to remove_no_ops below.
        zeros = set()
        ones = set()

        # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
        # so just constant-ify if a Tensor is unaliased
        constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()

        for node in cf.node_replacements:
            constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1

        for node, value in node_replacements.items():
            # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now
            # hasn't shown up to be important yet
            if "val" not in node.meta:
                # This can only happen in AOTI
                continue

            fake_tensor = node.meta["val"]
            if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
                continue

            # TODO - not sure about lossy uint->python value->uint conversions
            if fake_tensor.dtype in (
                torch.uint8,
                torch.uint16,
                torch.uint32,
                torch.uint64,
            ):
                continue

            # Skip aliased constants: materializing one alias separately would
            # change aliasing observable by the program.
            if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
                continue

            with graph.inserting_after(node):
                # the conversion from tensor and back to value can be lossy, just use the original full ctor value
                if (
                    node.op == "call_function"
                    and node.target == aten.full.default
                    and len(node.args) == 2
                ):
                    value = node.args[1]

                # refines symints, see [constant folding refining of symints] above
                for runtime_size, compile_time_size in zip(
                    node_replacements_shapes[node], fake_tensor.shape
                ):
                    torch._check(runtime_size == compile_time_size)

                # replace SymInt as Node before creating a new full node
                # e.g. (1, s0) -> (1, arg0_1)
                node_shape = node_replacements_shapes[node]
                if not all(
                    not isinstance(s, torch.SymInt) or s in cf.symint_nodes
                    for s in node_shape
                ):
                    # Some symbolic dim has no producing node; cannot rebuild.
                    continue

                shapes = [
                    cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
                    for s in node_replacements_shapes[node]
                ]

                # zeros and ones just get traced into full, so we insert those
                new_node = graph.call_function(
                    aten.full.default,
                    args=(shapes, value),
                    kwargs={
                        "dtype": fake_tensor.dtype,
                        "layout": torch.strided,
                        "device": fake_tensor.device,
                        "pin_memory": False,
                    },
                )

                new_node.meta.update(node.meta)
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)

                if value == 0:
                    zeros.add(new_node)
                elif value == 1:
                    ones.add(new_node)

        remove_no_ops(gm, zeros, ones)
        remove_redundant_views(gm)
433
+
434
+
435
def joint_graph_passes(graph: torch.fx.GraphModule):
    """
    Run FX transformations on the joint forwards+backwards graph.
    """
    lazy_init()
    # Counts mutating passes; sort/lint/recompile only happen if any ran.
    count = 0
    if config.joint_custom_pre_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_pre_pass(graph.graph)
            count += 1

    from .post_grad import remove_noop_ops

    remove_noop_ops(graph.graph)

    if config.joint_graph_constant_folding:
        with GraphTransformObserver(
            graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform
        ):
            constant_fold_uniform_value(graph)

    if config.pattern_matcher:
        for patterns in pass_patterns:
            count += patterns.apply(graph.graph)  # type: ignore[arg-type]

    if not config.fallback_random:
        count += replace_random_passes(graph)

    if config.joint_custom_post_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_post_pass(graph.graph)
            count += 1

    if count:
        # Mutations may have broken topological order; restore and verify it
        # before recompiling the module's forward.
        stable_topological_sort(graph.graph)
        graph.graph.lint()
        graph.recompile()
    return graph
477
+
478
+
479
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.iota.default,
        KeywordArg("length"),
        start=KeywordArg("start"),
        step=KeywordArg("step"),
        dtype=KeywordArg("dtype"),
        device=KeywordArg("device"),
        requires_grad=KeywordArg("requires_grad"),
    ),
    pass_dict=patterns,
)
def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad):
    """
    Eager supports:

        aten.index(cuda_tensor, torch.arange(..., device="cpu"))

    But this results in an implicit host-device-copy and breaks cudagraphs.
    Rewrite the arange to use CUDA.
    """
    (node,) = match.nodes
    user_devices: OrderedSet[torch.device] = OrderedSet()
    for user in node.users:
        if (
            user.op == "call_function"
            and user.target in (aten.index.Tensor, aten.index_put.default)
            and hasattr(user.meta.get("val"), "device")
        ):
            user_devices.add(user.meta["val"].device)  # type: ignore[union-attr]
        else:
            # Any other kind of user may need the iota on its original
            # device, so leave the graph untouched.
            return  # bail out

    if len(user_devices) == 1 and "val" in node.meta:
        (user_device,) = user_devices
        if device.type != user_device.type:
            # Rebuild the iota on the device of its (index) consumers.
            repl = match.graph.call_function(
                torch.ops.prims.iota.default,
                (length,),
                {
                    "start": start,
                    "step": step,
                    "dtype": dtype,
                    "device": user_device,
                    "requires_grad": requires_grad,
                },
            )
            repl.meta.update(node.meta)
            repl.meta["val"] = repl.meta["val"].to(user_device)
            node.replace_all_uses_with(repl)
            match.erase_nodes()
530
+
531
+
532
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        CallFunction(
            torch.ops.prims.convert_element_type.default,
            KeywordArg("arg"),
            KeywordArg("dtype1"),
        ),
        KeywordArg("dtype2"),
    ),
    pass_dict=patterns,
)
def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
    """Collapse back-to-back dtype conversions (common with AMP) into one."""
    floating = {torch.float16, torch.bfloat16, torch.float32, torch.float64}
    if dtype1 not in floating or dtype2 not in floating:
        return
    # A float -> float -> float chain is equivalent to a single conversion
    # straight to the final dtype.
    graph = match.graph
    node = match.output_node()
    direct = graph.call_function(
        torch.ops.prims.convert_element_type.default, (arg, dtype2)
    )
    direct.meta.update(node.meta)
    node.replace_all_uses_with(direct)
    match.erase_nodes()
556
+
557
+
558
@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Drop aten.view calls whose target size equals the input's shape."""
    view_node = match.output_node()
    src = view_node.args[0]
    src_shape = list(src.meta["val"].shape)  # type: ignore[union-attr]
    if src_shape == size:
        view_node.replace_all_uses_with(src)  # type: ignore[arg-type]
        match.erase_nodes()
569
+
570
+
571
+ # When softmax is used with temperature or other scaling, we get the pattern
572
+ #
573
+ # scale(x) - scale(x).amax(dim, keepdim=True)
574
+ #
575
+ # which is expected to be at most zero, but we may end up with numerical
576
+ # discrepancies # between the recomputed values of scale(x) inside and out
577
+ # of the reduction, # depending on compiler optimizations, e.g. use of fma
578
+ # instructions.
579
+ #
580
+ # Here we replace it with the mathematically equivalent,
581
+ #
582
+ # scale(x - x.amax(dim, keepdim=True))
583
+ #
584
+ # which is more stable as we only compute the scaling once.
585
+ #
586
+ # NOTE: This pattern must come after fused attention matching!
587
+
588
+
589
def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False):
    """Build the pattern `scaled - scaled.amax(dim, keepdim)` where
    `scaled = linear_func(inp, other)` (operand order flipped when
    `reverse`), optionally wrapped in a dtype conversion first."""
    # Allow matching inp * other and other * input
    if reverse:
        operands = (KeywordArg("other"), KeywordArg("inp"))
    else:
        operands = (KeywordArg("inp"), KeywordArg("other"))
    scaled = CallFunction(linear_func, *operands, _users=MULTIPLE)
    if to_dtype:
        scaled = CallFunction(
            prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE
        )
    amax = CallFunction(
        aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim")
    )
    return CallFunction(aten.sub.Tensor, scaled, amax)
607
+
608
+
609
+ def _other_is_broadcasted_in_dim(match):
610
+ # Check that the scaling factor is constant across the reduction dim,
611
+ # so scaling doesn't change which index corresponds to the maximum value
612
+ other = match.kwargs["other"]
613
+ if isinstance(other, (int, float)):
614
+ return True
615
+
616
+ inp = match.kwargs["inp"]
617
+ if not all(isinstance(x, torch.fx.Node) for x in (inp, other)):
618
+ return False
619
+
620
+ inp_example = inp.meta["val"]
621
+ other_example = other.meta["val"]
622
+ if isinstance(other_example, (torch.SymInt, torch.SymFloat)):
623
+ return True
624
+
625
+ if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)):
626
+ return False
627
+
628
+ inp_ndim = inp_example.ndim
629
+ other_shape = other_example.shape
630
+ if inp_ndim < len(other_shape):
631
+ return False
632
+
633
+ # Pad other_shape to the same ndim as inp
634
+ other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape)
635
+
636
+ dim = match.kwargs["dim"]
637
+ if isinstance(dim, int):
638
+ dim = (dim,)
639
+
640
+ return all(statically_known_true(other_shape[d] == 1) for d in dim)
641
+
642
+
643
def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    """Rewrite scale(x) - scale(x).amax(...) into scale(x - x.amax(...)) for
    multiplication, compensating for the sign of the scale so the max is
    taken over correctly-oriented values."""

    def replacement(inp, other):
        if dtype is not None:
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            unit = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, unit, -unit)

        # Orient by sign so amax picks the element that maximizes scale(x).
        oriented = inp * sign
        peak = torch.amax(oriented, dim=dim, keepdim=keepdim)
        return (oriented - peak) * (sign * other)

    match.replace_by_example(replacement, [inp, other])
660
+
661
+
662
# Register the mul-based partial-softmax stabilization for every combination
# of operand order (inp*other vs other*inp) and optional dtype conversion.
for reverse, to_dtype in itertools.product((False, True), repeat=2):
    register_graph_pattern(
        _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(mul_softmax_pattern)
668
+
669
+
670
def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    """Rewrite scale(x) - scale(x).amax(...) into scale(x - x.amax(...)) for
    division, compensating for the sign of the divisor so the max is taken
    over correctly-oriented values."""

    def replacement(inp, other):
        if dtype is not None:
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            unit = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, unit, -unit)

        # Orient by sign so amax picks the element that maximizes scale(x).
        oriented = inp * sign
        peak = torch.amax(oriented, dim=dim, keepdim=keepdim)
        return (oriented - peak) / (sign * other)

    match.replace_by_example(replacement, [inp, other])
687
+
688
+
689
# Division is not commutative, so only the inp/other operand order is
# registered, with and without the optional dtype conversion.
for to_dtype in (False, True):
    register_graph_pattern(
        _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(div_softmax_pattern)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import operator
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, cast, Dict, List, Optional, Set
6
+
7
+ import torch
8
+
9
+ from .. import config, inductor_prims
10
+ from ..pattern_matcher import (
11
+ CallFunction,
12
+ Ignored,
13
+ KeywordArg,
14
+ ListOf,
15
+ Match,
16
+ MULTIPLE,
17
+ PatternExpr,
18
+ PatternMatcherPass,
19
+ )
20
+
21
+
22
+ aten = torch.ops.aten
23
+ patterns = PatternMatcherPass()
24
+
25
+
26
+ def _is_backward(graph: torch.fx.Graph) -> bool:
27
+ placeholders = []
28
+ for node in graph.nodes:
29
+ if node.op != "placeholder":
30
+ break
31
+ placeholders.append(node)
32
+ return not all(node.name.startswith("primal") for node in placeholders)
33
+
34
+
35
+ def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float:
36
+ return M * N * K / (M * K + N * K + M * N)
37
+
38
+
39
+ def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]:
40
+ return [x for x in nodes if x.target == target]
41
+
42
+
43
+ def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]:
44
+ ancestors = set()
45
+ ancestors.add(node)
46
+ cur_nodes = [node]
47
+ while len(cur_nodes) > 0:
48
+ new_nodes = []
49
+ for node in cur_nodes:
50
+ for inp in node.all_input_nodes:
51
+ if inp not in ancestors:
52
+ ancestors.add(inp)
53
+ new_nodes.append(inp)
54
+ cur_nodes = new_nodes
55
+ return {node for node in ancestors if node.op != "placeholder"}
56
+
57
+
58
+ def _get_tensor(node: torch.fx.Node) -> torch.Tensor:
59
+ val = node.meta["val"]
60
+ assert isinstance(val, torch.Tensor)
61
+ return val
62
+
63
+
64
@dataclass
class _AllGatherMatch:
    """A matched all-gather subgraph.

    match: pattern-matcher Match covering the subgraph.
    shard_node: local shard fed into the all-gather.
    ag_node: the all_gather_into_tensor collective node.
    res_node: node producing the final gathered result.
    gather_dim: dimension along which shards are concatenated.
    group_name: process-group name used by the collective.
    """

    match: Match
    shard_node: torch.fx.Node
    ag_node: torch.fx.Node
    res_node: torch.fx.Node
    gather_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        # Redirect all users of the gathered result to `new_node`.
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        # Erase in reverse match order so consumers go before producers;
        # nodes that still have users are left intact.
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)
80
+
81
+
82
def find_all_gather_patterns(graph: torch.fx.Graph):
    """Find funcol-style all-gather subgraphs in `graph` and return them as
    _AllGatherMatch records, in (roughly) graph order.

    Handles gather_dim == 0 and > 0, and the variants where data is
    transferred as uint8 and viewed back to the original dtype.
    """
    c10d = torch.ops._c10d_functional

    def make_zero_dim_all_gather_pattern(shard):
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.all_gather_into_tensor.default,
                shard,
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.all_gather_tensor with gather_dim == 0
    zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard"))

    def make_all_gather_split_pattern(shard):
        return CallFunction(
            operator.getitem,
            CallFunction(
                aten.split.Tensor,
                make_zero_dim_all_gather_pattern(shard),
                Ignored(),
                _users=MULTIPLE,
            ),
            Ignored(),
        )

    def make_cat_pattern(splits):
        return CallFunction(
            aten.cat.default,
            ListOf(splits),
            KeywordArg("gather_dim"),
        )

    # Matches funcol.all_gather_tensor with gather_dim > 0
    non_zero_dim_all_gather_pattern = make_cat_pattern(
        make_all_gather_split_pattern(KeywordArg("shard")),
    )

    # Match a zero-dim all-gather in which the data is transferred as uint8 and
    # viewed back as the original dtype.
    zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_zero_dim_all_gather_pattern(
            KeywordArg("shard"),
        ),
        Ignored(),
    )

    # Match a non-zero dim all-gather in which the data is transferred as uint8
    # and viewed back as the original dtype.
    non_zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_cat_pattern(
            CallFunction(
                aten.view.dtype,
                make_all_gather_split_pattern(
                    KeywordArg("shard"),
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    )

    # If two patterns with the same res_node_target have the same suffix, the
    # longer pattern should appear first in the list.
    # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1)
    # should appear before (2) in the list.
    res_node_target_to_patterns = {
        aten.cat.default: [
            (non_zero_dim_all_gather_pattern, 0),
        ],
        aten.view.dtype: [
            (non_zero_dim_type_erased_all_gather_pattern, 0),
            (zero_dim_type_erased_all_gather_pattern, 0),
        ],
        c10d.wait_tensor.default: [
            (zero_dim_all_gather_pattern, 0),
        ],
    }

    # Match in reverse to ensure longer patterns is prioritized
    all_gathers = []
    visited_ag_nodes = set()
    for node in reversed(graph.nodes):
        for target, patterns in res_node_target_to_patterns.items():
            if node.target != target:
                continue
            for pattern, ag_node_idx in patterns:
                match = pattern.match(node)
                if not match:
                    continue

                assert isinstance(match, Match)
                ag_node = match.nodes[ag_node_idx]
                assert ag_node.target == c10d.all_gather_into_tensor.default

                # Each collective is reported once, even if multiple
                # result-node patterns overlap it.
                if ag_node in visited_ag_nodes:
                    continue
                visited_ag_nodes.add(ag_node)

                ag_match = _AllGatherMatch(
                    match=match,
                    shard_node=match.kwargs["shard"],
                    ag_node=ag_node,
                    res_node=node,
                    gather_dim=match.kwargs.get("gather_dim", 0),
                    group_name=match.kwargs["group_name"],
                )
                all_gathers.append(ag_match)

    # We scanned backwards; reverse to restore forward graph order.
    return list(reversed(all_gathers))
197
+
198
+
199
+ @dataclass
200
+ class _ReduceScatterMatch:
201
+ match: Match
202
+ input_node: torch.fx.Node
203
+ rs_node: torch.fx.Node
204
+ res_node: torch.fx.Node
205
+ reduce_op: str
206
+ scatter_dim: int
207
+ group_name: str
208
+
209
+ def replace_with(self, new_node: torch.fx.Node) -> None:
210
+ self.res_node.replace_all_uses_with(new_node)
211
+
212
+ def erase(self) -> None:
213
+ for node in reversed(self.match.nodes):
214
+ if len(node.users) == 0:
215
+ node.graph.erase_node(node)
216
+
217
+
218
+ def find_reduce_scatter_patterns(graph: torch.fx.Graph):
219
+ c10d = torch.ops._c10d_functional
220
+
221
+ def reduce_scatter_template(inp: PatternExpr):
222
+ return CallFunction(
223
+ c10d.wait_tensor.default,
224
+ CallFunction(
225
+ c10d.reduce_scatter_tensor.default,
226
+ inp,
227
+ KeywordArg("reduce_op"),
228
+ Ignored(),
229
+ KeywordArg("group_name"),
230
+ ),
231
+ )
232
+
233
+ # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
234
+ zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))
235
+
236
+ # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
237
+ non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
238
+ CallFunction(
239
+ aten.cat.default,
240
+ ListOf(
241
+ CallFunction(
242
+ operator.getitem,
243
+ CallFunction(
244
+ aten.split.Tensor,
245
+ KeywordArg("input"),
246
+ Ignored(),
247
+ KeywordArg("scatter_dim"),
248
+ _users=MULTIPLE,
249
+ ),
250
+ Ignored(),
251
+ )
252
+ ),
253
+ ),
254
+ )
255
+
256
+ reduce_scatters = []
257
+ for node in reversed(graph.nodes):
258
+ if node.target == c10d.wait_tensor.default:
259
+ if match := non_zero_dim_reduce_scatter_pattern.match(node):
260
+ assert isinstance(match, Match)
261
+ reduce_scatters.append(
262
+ _ReduceScatterMatch(
263
+ match=match,
264
+ input_node=match.kwargs["input"],
265
+ rs_node=match.nodes[-2],
266
+ res_node=node,
267
+ reduce_op=match.kwargs["reduce_op"],
268
+ scatter_dim=match.kwargs["scatter_dim"],
269
+ group_name=match.kwargs["group_name"],
270
+ )
271
+ )
272
+ elif match := zero_dim_reduce_scatter_pattern.match(node):
273
+ assert isinstance(match, Match)
274
+ reduce_scatters.append(
275
+ _ReduceScatterMatch(
276
+ match=match,
277
+ input_node=match.kwargs["input"],
278
+ rs_node=match.nodes[0],
279
+ res_node=node,
280
+ reduce_op=match.kwargs["reduce_op"],
281
+ scatter_dim=0,
282
+ group_name=match.kwargs["group_name"],
283
+ )
284
+ )
285
+ return list(reversed(reduce_scatters))
286
+
287
+
288
+ @dataclass
289
+ class _Matmul:
290
+ nodes: List[torch.fx.Node]
291
+ arg_ancestor_nodes: Set[torch.fx.Node] = field(init=False)
292
+ A_node: torch.fx.Node
293
+ B_node: torch.fx.Node
294
+
295
+ def __post_init__(self):
296
+ assert len(self.nodes) in (1, 3)
297
+ if len(self.nodes) == 1:
298
+ assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
299
+ else:
300
+ assert self.nodes[0].target == aten.reshape.default
301
+ assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
302
+ assert self.nodes[2].target == aten.reshape.default
303
+ self.arg_ancestor_nodes = _find_ancestors(self.B_node)
304
+
305
+ def replace_with(self, new_node: torch.fx.Node) -> None:
306
+ """
307
+ Replace the matmul with the new node.
308
+ """
309
+ graph = new_node.graph
310
+
311
+ # For 2D-matmuls, we simply replace the mm node with `new_node`.
312
+ if len(self.nodes) == 1:
313
+ mm_node = self.nodes[0]
314
+ assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
315
+ mm_node.replace_all_uses_with(new_node)
316
+ graph.erase_node(mm_node)
317
+ return
318
+
319
+ # An ND-matmul is reshape -> mm -> reshape sequence. We first replace
320
+ # the second reshape node with `new_node`. Then, we ensure that the
321
+ # original mm node in the sequence ends up with zero users by replacing
322
+ # it with a reverse reshape of `new_node`.
323
+ graph = new_node.graph
324
+ assert len(self.nodes) == 3
325
+ mm_node = self.nodes[1]
326
+ output_reshape_node = self.nodes[2]
327
+
328
+ assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
329
+ assert output_reshape_node.target == aten.reshape.default
330
+
331
+ output_reshape_node.replace_all_uses_with(new_node)
332
+ if len(mm_node.users) > 1:
333
+ with graph.inserting_after(new_node):
334
+ new_mm_node = graph.call_function(
335
+ aten.reshape.default,
336
+ args=(new_node, list(_get_tensor(mm_node).shape)),
337
+ )
338
+ mm_node.replace_all_uses_with(new_mm_node)
339
+
340
+ def erase(self) -> None:
341
+ for node in reversed(self.nodes):
342
+ if len(node.users) == 0:
343
+ node.graph.erase_node(node)
344
+
345
+ @classmethod
346
+ def from_match(cls, match: List[torch.fx.Node]) -> "_Matmul":
347
+ assert len(match) in (1, 3)
348
+ assert match[0].target in (
349
+ aten.mm.default,
350
+ aten.reshape.default,
351
+ )
352
+ mm_node = match[0] if len(match) == 1 else match[1]
353
+ return _Matmul(
354
+ nodes=match,
355
+ A_node=cast(torch.fx.Node, match[0].args[0]),
356
+ B_node=cast(torch.fx.Node, mm_node.args[1]),
357
+ )
358
+
359
+
360
+ @dataclass
361
+ class _ScaledMatmul(_Matmul):
362
+ A_scale_node: torch.fx.Node
363
+ B_scale_node: torch.fx.Node
364
+ bias_node: Optional[torch.fx.Node]
365
+ result_scale_node: Optional[torch.fx.Node]
366
+ out_dtype: Optional[torch.dtype]
367
+ use_fast_accum: bool
368
+
369
+ def __post_init__(self):
370
+ super().__post_init__()
371
+ self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node)
372
+ self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node)
373
+
374
+ @classmethod
375
+ def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
376
+ assert len(match) in (1, 3)
377
+ assert match[0].target in (
378
+ aten._scaled_mm.default,
379
+ aten.reshape.default,
380
+ )
381
+ mm_node = match[0] if len(match) == 1 else match[1]
382
+
383
+ def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
384
+ if idx >= len(node.args):
385
+ return default
386
+ return node.args[idx]
387
+
388
+ return _ScaledMatmul(
389
+ nodes=match,
390
+ A_node=cast(torch.fx.Node, match[0].args[0]),
391
+ B_node=cast(torch.fx.Node, mm_node.args[1]),
392
+ A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
393
+ B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
394
+ bias_node=get_arg(mm_node, 4, None),
395
+ result_scale_node=get_arg(mm_node, 5, None),
396
+ out_dtype=get_arg(mm_node, 6, None),
397
+ use_fast_accum=get_arg(mm_node, 7, False),
398
+ )
399
+
400
+
401
+ def _find_reshape_mm_reshape(node: torch.fx.Node) -> List[_Matmul]:
402
+ if node.target != aten.reshape.default:
403
+ return []
404
+
405
+ matches = []
406
+ for mm_node in node.users:
407
+ if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
408
+ continue
409
+ for reshape_node in mm_node.users:
410
+ if reshape_node.target != aten.reshape.default:
411
+ continue
412
+
413
+ # Since the reshape -> mm -> reshape pattern would be subsumed into
414
+ # the fused op, we only match the patterns where the shape of the
415
+ # second reshape is matches the mm result produced by the fused op.
416
+ matmul_input_node = cast(torch.fx.Node, node.args[0])
417
+ B_node = cast(torch.fx.Node, mm_node.args[1])
418
+ matmul_out_shape = torch.Size(
419
+ [
420
+ *_get_tensor(matmul_input_node).shape[:-1],
421
+ _get_tensor(B_node).shape[-1],
422
+ ]
423
+ )
424
+ if _get_tensor(reshape_node).shape != matmul_out_shape:
425
+ continue
426
+ matches.append([node, mm_node, reshape_node])
427
+ # If for some rare reason mm_node is being reshaped by two
428
+ # different reshape nodes, we only include mm_node once in the
429
+ # parsing result.
430
+ break
431
+
432
+ matmuls = []
433
+ for match in matches:
434
+ mm_node = match[1]
435
+ if mm_node.target == aten.mm.default:
436
+ matmul = _Matmul.from_match(match)
437
+ matmuls.append(matmul)
438
+ elif mm_node.target == aten._scaled_mm.default:
439
+ matmul = _ScaledMatmul.from_match(match)
440
+ matmuls.append(matmul)
441
+ else:
442
+ raise AssertionError(
443
+ "Expect the node's target to be either aten.mm.default or "
444
+ f"aten._scaled_mm.default. Got {mm_node.target}."
445
+ )
446
+ return matmuls
447
+
448
+
449
+ def _find_consumer_matmuls(node: torch.fx.Node) -> List[_Matmul]:
450
+ """
451
+ Find the matmuls that use `node` as the lhs argument.
452
+ """
453
+ matmuls = []
454
+ for user in node.users:
455
+ # ND matmuls
456
+ if user.target == aten.reshape.default:
457
+ matmuls.extend(_find_reshape_mm_reshape(user))
458
+ # 2D matmuls
459
+ elif user.target == aten.mm.default:
460
+ matmul = _Matmul.from_match(match=[user])
461
+ matmuls.append(matmul)
462
+ elif user.target == aten._scaled_mm.default:
463
+ matmul = _ScaledMatmul.from_match([user])
464
+ matmuls.append(matmul)
465
+ return matmuls
466
+
467
+
468
+ def _insert_fused_all_gather_matmul(
469
+ graph: torch.fx.Graph,
470
+ matmuls: List[_Matmul],
471
+ shard_node: torch.fx.Node,
472
+ gather_dim: int,
473
+ group_name: str,
474
+ ) -> torch.fx.Node:
475
+ mm_types = set(map(type, matmuls))
476
+ assert len(mm_types) == 1
477
+ mm_type = next(iter(mm_types))
478
+ if mm_type == _Matmul:
479
+ B_nodes = [matmul.B_node for matmul in matmuls]
480
+ return graph.call_function(
481
+ torch.ops.symm_mem.fused_all_gather_matmul.default,
482
+ args=(shard_node, B_nodes, gather_dim, group_name),
483
+ )
484
+ elif mm_type == _ScaledMatmul:
485
+ scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
486
+ return graph.call_function(
487
+ torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
488
+ args=(
489
+ shard_node,
490
+ [matmul.B_node for matmul in scaled_matmuls],
491
+ scaled_matmuls[0].A_scale_node,
492
+ [matmul.B_scale_node for matmul in scaled_matmuls],
493
+ gather_dim,
494
+ group_name,
495
+ [matmul.bias_node for matmul in scaled_matmuls],
496
+ [matmul.result_scale_node for matmul in scaled_matmuls],
497
+ [matmul.out_dtype for matmul in scaled_matmuls],
498
+ [matmul.use_fast_accum for matmul in scaled_matmuls],
499
+ ),
500
+ )
501
+ else:
502
+ raise AssertionError(f"Unexpected matmul match type: {mm_type}")
503
+
504
+
505
+ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
506
+ """
507
+ Fused the pattern
508
+
509
+ A = all_gather_tensor(A_shard, gather_dim, group_name)
510
+ C_0 = torch.matmul(A, B_0)
511
+ C_1 = torch.matmul(A, B_1)
512
+ C_2 = torch.matmul(A, B_2)
513
+ ...
514
+
515
+ into
516
+
517
+ A, Cs = torch.ops.symm_mem.fused_all_gather_matmul(
518
+ A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name,
519
+ )
520
+ """
521
+ if (
522
+ not torch.distributed.is_available()
523
+ or not torch.distributed.is_nccl_available()
524
+ ):
525
+ return
526
+
527
+ c10d = torch.ops._c10d_functional
528
+ from torch.distributed._symmetric_memory import (
529
+ is_symm_mem_enabled_for_group,
530
+ restride_A_shard_for_fused_all_gather_matmul,
531
+ )
532
+
533
+ shard_node, ag_node, ag_res_node, gather_dim, group_name = (
534
+ all_gather.shard_node,
535
+ all_gather.ag_node,
536
+ all_gather.res_node,
537
+ all_gather.gather_dim,
538
+ all_gather.group_name,
539
+ )
540
+
541
+ if not is_symm_mem_enabled_for_group(group_name):
542
+ return
543
+
544
+ if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
545
+ # Decomposing the matmul on the K dimension is not supported
546
+ return
547
+
548
+ # Find consumer matmuls
549
+ matmuls = _find_consumer_matmuls(ag_res_node)
550
+
551
+ # The matmuls are only fusible if non-A args don't depend on the all-gather
552
+ # result node
553
+ matmuls = [
554
+ matmul
555
+ for matmul in matmuls
556
+ if all_gather.res_node not in matmul.arg_ancestor_nodes
557
+ ]
558
+
559
+ if len(matmuls) == 0 or len(set(map(type, matmuls))) != 1:
560
+ return
561
+
562
+ # Fuse the all_gather_tensor with the eligible matmuls
563
+ graph = ag_node.graph
564
+ with graph.inserting_before(ag_node):
565
+ if "val" in shard_node.meta:
566
+ restrided = restride_A_shard_for_fused_all_gather_matmul(
567
+ _get_tensor(shard_node),
568
+ gather_dim,
569
+ )
570
+ shard_node = graph.call_function(
571
+ inductor_prims.force_stride_order,
572
+ args=(shard_node, restrided.stride()),
573
+ )
574
+
575
+ fused_node = _insert_fused_all_gather_matmul(
576
+ graph, matmuls, shard_node, gather_dim, group_name
577
+ )
578
+ new_ag_node = graph.call_function(
579
+ operator.getitem,
580
+ args=(fused_node, 0),
581
+ )
582
+ new_out_nodes = graph.call_function(
583
+ operator.getitem,
584
+ args=(fused_node, 1),
585
+ )
586
+ for idx, matmul in enumerate(matmuls):
587
+ new_out_node = graph.call_function(
588
+ operator.getitem,
589
+ args=(new_out_nodes, idx),
590
+ )
591
+ matmul.replace_with(new_out_node)
592
+ matmul.erase()
593
+ all_gather.replace_with(new_ag_node)
594
+ all_gather.erase()
595
+
596
+ # Raise ancestors of non-A args that are topologically ordered between
597
+ # ag_res_node and the matmul above fused_node.
598
+ order = {node: idx for idx, node in enumerate(graph.nodes)}
599
+ nodes_to_raise = sorted(
600
+ {x for matmul in matmuls for x in matmul.arg_ancestor_nodes},
601
+ key=lambda x: order[x],
602
+ )
603
+ for node in nodes_to_raise:
604
+ if order[node] > order[fused_node]:
605
+ fused_node.prepend(node)
606
+
607
+
608
+ def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
609
+ if node.target == aten.mm.default:
610
+ return _Matmul.from_match(match=[node])
611
+ elif node.target == aten._scaled_mm.default:
612
+ return _ScaledMatmul.from_match(match=[node])
613
+ elif node.target == aten.reshape.default:
614
+ reshape_node_1 = node
615
+
616
+ mm_node = reshape_node_1.args[0]
617
+ assert isinstance(mm_node, torch.fx.Node)
618
+ if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
619
+ return None
620
+
621
+ reshape_node_0 = mm_node.args[0]
622
+ assert isinstance(reshape_node_0, torch.fx.Node)
623
+ if reshape_node_0.target != aten.reshape.default:
624
+ return None
625
+
626
+ if mm_node.target == aten.mm.default:
627
+ return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1])
628
+ elif mm_node.target == aten._scaled_mm.default:
629
+ return _ScaledMatmul.from_match(
630
+ match=[reshape_node_0, mm_node, reshape_node_1]
631
+ )
632
+ return None
633
+
634
+
635
+ def _insert_fused_matmul_reduce_scatter(
636
+ graph: torch.fx.Graph,
637
+ matmul: _Matmul,
638
+ reduce_op: str,
639
+ scatter_dim: int,
640
+ group_name: str,
641
+ ) -> torch.fx.Node:
642
+ if type(matmul) == _Matmul:
643
+ return graph.call_function(
644
+ torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
645
+ args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name),
646
+ )
647
+ elif type(matmul) == _ScaledMatmul:
648
+ return graph.call_function(
649
+ torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
650
+ args=(
651
+ matmul.A_node,
652
+ matmul.B_node,
653
+ matmul.A_scale_node,
654
+ matmul.B_scale_node,
655
+ reduce_op,
656
+ scatter_dim,
657
+ group_name,
658
+ matmul.bias_node,
659
+ matmul.result_scale_node,
660
+ matmul.out_dtype,
661
+ matmul.use_fast_accum,
662
+ ),
663
+ )
664
+ else:
665
+ raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
666
+
667
+
668
+ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
669
+ """
670
+ Fused the pattern
671
+
672
+ reduce_scatter_tensor(A @ B, scatter_dim, group_name)
673
+
674
+ into
675
+
676
+ torch.ops.symm_mem.fused_matmul_reduce_scatter(
677
+ A, B, scatter_dim, group_name,
678
+ )
679
+ """
680
+ if (
681
+ not torch.distributed.is_available()
682
+ or not torch.distributed.is_nccl_available()
683
+ ):
684
+ return
685
+
686
+ c10d = torch.ops._c10d_functional
687
+ from torch.distributed._symmetric_memory import (
688
+ is_symm_mem_enabled_for_group,
689
+ restride_A_for_fused_matmul_reduce_scatter,
690
+ )
691
+
692
+ input_node, rs_node, rs_res_node, reduce_op, scatter_dim, group_name = (
693
+ reduce_scatter.input_node,
694
+ reduce_scatter.rs_node,
695
+ reduce_scatter.res_node,
696
+ reduce_scatter.reduce_op,
697
+ reduce_scatter.scatter_dim,
698
+ reduce_scatter.group_name,
699
+ )
700
+
701
+ if not is_symm_mem_enabled_for_group(group_name):
702
+ return
703
+
704
+ # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
705
+ # so we can't apply the fusion if the matmul result is used by multiple
706
+ # users. This is not a fundamental limitation of the fused op and can be
707
+ # addressed if needed.
708
+ if len(input_node.users) != 1:
709
+ return
710
+
711
+ matmul = _find_producer_matmul(input_node)
712
+ if matmul is None:
713
+ return
714
+
715
+ if rs_res_node in matmul.arg_ancestor_nodes:
716
+ return
717
+
718
+ graph = rs_res_node.graph
719
+ with graph.inserting_before(rs_res_node):
720
+ if "val" in matmul.A_node.meta:
721
+ restrided = restride_A_for_fused_matmul_reduce_scatter(
722
+ _get_tensor(matmul.A_node),
723
+ scatter_dim,
724
+ )
725
+ matmul.A_node = graph.call_function(
726
+ inductor_prims.force_stride_order,
727
+ args=(matmul.A_node, restrided.stride()),
728
+ )
729
+
730
+ fused_node = _insert_fused_matmul_reduce_scatter(
731
+ graph,
732
+ matmul,
733
+ reduce_op,
734
+ scatter_dim,
735
+ group_name,
736
+ )
737
+ reduce_scatter.replace_with(fused_node)
738
+ reduce_scatter.erase()
739
+ matmul.erase()
740
+
741
+ order = {node: idx for idx, node in enumerate(graph.nodes)}
742
+ nodes_to_raise = sorted(
743
+ matmul.arg_ancestor_nodes,
744
+ key=lambda x: order[x],
745
+ )
746
+ for node in nodes_to_raise:
747
+ if order[node] > order[fused_node]:
748
+ fused_node.prepend(node)
749
+
750
+
751
+ def _get_node_to_ancestors(
752
+ graph: torch.fx.Graph,
753
+ ) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
754
+ """
755
+ Compute the ancestors for all nodes in a graph.
756
+ """
757
+ node_to_ancestors = defaultdict(set)
758
+ for node in graph.nodes:
759
+ node_to_ancestors[node] = set(node.all_input_nodes)
760
+ for dep in node.all_input_nodes:
761
+ node_to_ancestors[node] |= node_to_ancestors[dep]
762
+
763
+ return node_to_ancestors
764
+
765
+
766
+ def _get_collective_to_overlappable_nodes(
767
+ graph: torch.fx.Graph,
768
+ ) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
769
+ """
770
+ For each collective in the graph, find nodes that are neither ancestors nor
771
+ descendants of the collective.
772
+ """
773
+
774
+ def is_collective(node) -> bool:
775
+ # Only consider all-gather and reduce-scatter in the context of
776
+ # micro-pipeline TP.
777
+ return node.target in [
778
+ torch.ops._c10d_functional.all_gather_into_tensor.default,
779
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
780
+ ]
781
+
782
+ node_to_ancestors = _get_node_to_ancestors(graph)
783
+ collective_to_overlappable_nodes = defaultdict(list)
784
+ for node in graph.nodes:
785
+ if not is_collective(node):
786
+ continue
787
+ for x in graph.nodes:
788
+ if (
789
+ node not in node_to_ancestors[x]
790
+ and x not in node_to_ancestors[node]
791
+ and x.op == "call_function"
792
+ ):
793
+ collective_to_overlappable_nodes[node].append(x)
794
+
795
+ return collective_to_overlappable_nodes
796
+
797
+
798
+ def _get_unexposed_collectives(graph: torch.fx.Graph) -> List[torch.fx.Node]:
799
+ """
800
+ Find all unexposed collectives in the graph.
801
+
802
+ Because we don't have the runtime estimate, this function is a rough
803
+ estimation using the following strong/hand-wavy assumptions:
804
+
805
+ - Only a predefined set of "compute intensive" operation can hide a collective.
806
+ - Any "compute intensive" operation can hide exactly one collective.
807
+ """
808
+
809
+ def _is_compute_intensive(node: torch.fx.Node) -> bool:
810
+ return node.target in [torch.ops.aten.mm.default]
811
+
812
+ collective_to_overlapping_candidates = defaultdict(list)
813
+ available_nodes = set()
814
+ collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph)
815
+ for collective, overlappable_nodes in collective_to_overlappable_nodes.items():
816
+ candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)]
817
+ collective_to_overlapping_candidates[collective] = candidates
818
+ available_nodes |= set(candidates)
819
+
820
+ unexposed_collectives = []
821
+ for (
822
+ collective,
823
+ overlapping_candidates,
824
+ ) in collective_to_overlapping_candidates.items():
825
+ # Each collective consumes exactly one overlapping candidate
826
+ for x in overlapping_candidates:
827
+ if x in available_nodes:
828
+ unexposed_collectives.append(collective)
829
+ available_nodes.remove(x)
830
+ break
831
+ return unexposed_collectives
832
+
833
+
834
+ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
835
+ all_gathers = find_all_gather_patterns(graph)
836
+ reduce_scatters = find_reduce_scatter_patterns(graph)
837
+
838
+ # When a collective can be hidden through either simple overlapping or
839
+ # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
840
+ # associated with decomposition. If reorder_for_compute_comm_overlap is
841
+ # enabled, we identify collectives that can be hidden through simple
842
+ # overlapping and exclude them from micro-pipeline TP candidates.
843
+ if config.reorder_for_compute_comm_overlap:
844
+ unexposed_collectives = _get_unexposed_collectives(graph)
845
+ all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
846
+ reduce_scatters = [
847
+ x for x in reduce_scatters if x.rs_node not in unexposed_collectives
848
+ ]
849
+
850
+ for all_gather in all_gathers:
851
+ fuse_all_gather_matmul(all_gather)
852
+
853
+ for reduce_scatter in reduce_scatters:
854
+ fuse_matmul_reduce_scatter(reduce_scatter)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ from typing import Dict, Set, Tuple
4
+
5
+ import torch
6
+ from torch._dynamo.utils import counters
7
+ from torch._ops import OpOverload, OpOverloadPacket
8
+
9
+ from ..pattern_matcher import fwd_only, register_replacement
10
+
11
+
12
+ aten = torch.ops.aten
13
+
14
+
15
+ @functools.lru_cache(None)
16
+ def _misc_patterns_init():
17
+ from .joint_graph import patterns as joint_graph_patterns
18
+ from .post_grad import pass_patterns as post_grad_patterns_all
19
+
20
+ post_grad_patterns = post_grad_patterns_all[1] # medium priority
21
+
22
+ if torch.cuda.is_available():
23
+ # workaround https://github.com/pytorch/pytorch/issues/97894
24
+ device = "cuda"
25
+ else:
26
+ device = "cpu"
27
+
28
+ # These patterns do 2 things
29
+ # 1. Since we know that index is completely unique, we can codegen it using
30
+ # stores instead of atomic adds, which is quite a bit faster.
31
+ # 2. Also, since we are guaranteed that they are completely within bounds,
32
+ # we can use unsafe indexing and skip debug asserts
33
+ def randperm_index_add_pattern(x, y):
34
+ index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
35
+ return torch.index_add(x, dim=0, source=y, index=index), index
36
+
37
+ def randperm_index_add_replacement(x, y):
38
+ index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
39
+ return (
40
+ torch.ops.aten._unsafe_index_put(
41
+ x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
42
+ ),
43
+ index,
44
+ )
45
+
46
+ register_replacement(
47
+ randperm_index_add_pattern,
48
+ randperm_index_add_replacement,
49
+ [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
50
+ fwd_only,
51
+ [post_grad_patterns, joint_graph_patterns],
52
+ )
53
+
54
+ def randperm_index_pattern(x, slice_shape):
55
+ index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
56
+ return torch.ops.aten.index(x, (index,)), index
57
+
58
+ def randperm_index_replacement(x, slice_shape):
59
+ index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
60
+ return torch.ops.aten._unsafe_index(x, (index,)), index
61
+
62
+ register_replacement(
63
+ randperm_index_pattern,
64
+ randperm_index_replacement,
65
+ [torch.empty(4, 8, device=device)],
66
+ fwd_only,
67
+ [post_grad_patterns, joint_graph_patterns],
68
+ scalar_workaround={"slice_shape": 42},
69
+ )
70
+
71
+
72
+ class NumpyCompatNormalization:
73
+ numpy_compat: Dict[str, Tuple[str, ...]] = {
74
+ "dim": ("axis",),
75
+ "keepdim": ("keepdims",),
76
+ "input": ("x", "a", "x1"),
77
+ "other": ("x2",),
78
+ }
79
+ inverse_mapping: Dict[str, str]
80
+ cache: Dict["torch.fx.graph.Target", Set[str]]
81
+
82
+ def __init__(self) -> None:
83
+ self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"]
84
+ self.inverse_mapping = {}
85
+ for actual_kwarg, numpy_kwargs in self.numpy_compat.items():
86
+ for numpy_kwarg in numpy_kwargs:
87
+ assert numpy_kwarg not in self.inverse_mapping
88
+ self.inverse_mapping[numpy_kwarg] = actual_kwarg
89
+
90
+ def __call__(self, graph: torch.fx.Graph):
91
+ for node in graph.nodes:
92
+ if node.op != "call_function":
93
+ continue
94
+ if isinstance(node.target, (OpOverload, OpOverloadPacket)):
95
+ # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
96
+ continue
97
+ kwargs = node.kwargs
98
+
99
+ if node.target in self.cache:
100
+ replaceable_kwargs = self.cache[node.target]
101
+ else:
102
+ signatures = torch.fx.operator_schemas.get_signature_for_torch_op(
103
+ node.target
104
+ )
105
+ signatures = () if signatures is None else signatures
106
+ replaceable_kwargs = set()
107
+ for sig in signatures:
108
+ for param_name in sig.parameters.keys():
109
+ if param_name in self.numpy_compat:
110
+ replaceable_kwargs.update(self.numpy_compat[param_name])
111
+
112
+ self.cache[node.target] = replaceable_kwargs
113
+
114
+ if not replaceable_kwargs:
115
+ continue
116
+
117
+ new_kwargs = {}
118
+ kwargs_changed = False
119
+ for k, v in kwargs.items():
120
+ if k in replaceable_kwargs:
121
+ kwargs_changed = True
122
+ new_kwargs[self.inverse_mapping[k]] = v
123
+ else:
124
+ new_kwargs[k] = v
125
+
126
+ if kwargs_changed:
127
+ node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs)
128
+ counters["inductor"]["numpy_compat_normalization"] += 1
129
+
130
+
131
+ numpy_compat_normalization = NumpyCompatNormalization()
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py ADDED
@@ -0,0 +1,1266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import operator
4
+ from functools import reduce
5
+ from typing import Any, Tuple
6
+
7
+ import torch
8
+ from torch.fx.experimental.symbolic_shapes import has_free_symbols
9
+
10
+ from .. import ir
11
+ from ..lowering import lowerings as L
12
+ from ..pattern_matcher import (
13
+ Arg,
14
+ CallFunction,
15
+ filter_nodes,
16
+ get_arg_value,
17
+ KeywordArg,
18
+ MULTIPLE,
19
+ )
20
+ from ..virtualized import ops, V
21
+ from .freezing_patterns import register_freezing_graph_pattern
22
+ from .post_grad import register_lowering_pattern
23
+ from .quantization import (
24
+ _register_quantization_lowerings,
25
+ _register_quantization_weight_pack_pass,
26
+ _register_woq_lowerings,
27
+ )
28
+
29
+
30
+ if torch._C._has_mkldnn:
31
+ aten = torch.ops.aten
32
+ mkldnn = torch.ops.mkldnn
33
+ prims = torch.ops.prims
34
+
35
+ _conv_args = [Arg() for _ in range(10)]
36
+ _linear_args = [Arg() for _ in range(6)]
37
+ _conv_transpose_args = [Arg() for _ in range(11)]
38
+
39
+ def _conv_call(users=1):
40
+ return CallFunction(
41
+ mkldnn._convolution_pointwise.default, *_conv_args, _users=users
42
+ )
43
+
44
+ def _linear_call(users=1):
45
+ return CallFunction(
46
+ mkldnn._linear_pointwise.default, *_linear_args, _users=users
47
+ )
48
+
49
+ def _conv_transpose_call(users=1):
50
+ return CallFunction(
51
+ mkldnn._convolution_transpose_pointwise.default,
52
+ *_conv_transpose_args,
53
+ _users=users,
54
+ )
55
+
56
+ def _to_float(input_call, users=1):
57
+ return CallFunction(
58
+ prims.convert_element_type.default,
59
+ input_call,
60
+ KeywordArg("to_float"),
61
+ _users=users,
62
+ )
63
+
64
+ def _to_bf16(input_call):
65
+ return CallFunction(
66
+ prims.convert_element_type.default,
67
+ input_call,
68
+ KeywordArg("to_bf16"),
69
+ _users=1,
70
+ )
71
+
72
+ def _to_fp16(input_call):
73
+ return CallFunction(
74
+ prims.convert_element_type.default,
75
+ input_call,
76
+ KeywordArg("to_fp16"),
77
+ _users=1,
78
+ )
79
+
80
+ def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
81
+ # only insert to_dtype if lowp_dtype is True
82
+ computation_call = (
83
+ _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users)
84
+ )
85
+ out = unary_fusion(computation_call)
86
+ if lowp_dtype == torch.bfloat16:
87
+ return _to_bf16(out)
88
+ elif lowp_dtype == torch.float16:
89
+ return _to_fp16(out)
90
+ else:
91
+ return out
92
+
93
+ def _gelu_fusion_1(computation_call):
94
+ return CallFunction(
95
+ aten.mul,
96
+ CallFunction(aten.mul, computation_call, 0.5),
97
+ CallFunction(
98
+ aten.add,
99
+ CallFunction(
100
+ aten.erf,
101
+ CallFunction(aten.mul, computation_call, 0.7071067811865476),
102
+ ),
103
+ 1,
104
+ ),
105
+ )
106
+
107
+ def _gelu_fusion_2(computation_call):
108
+ return CallFunction(
109
+ aten.mul,
110
+ CallFunction(aten.mul, computation_call, 0.5),
111
+ CallFunction(
112
+ aten.add,
113
+ CallFunction(
114
+ aten.tanh,
115
+ CallFunction(
116
+ aten.mul,
117
+ CallFunction(
118
+ aten.add,
119
+ computation_call,
120
+ CallFunction(
121
+ aten.mul,
122
+ CallFunction(
123
+ aten.mul,
124
+ CallFunction(
125
+ aten.mul, computation_call, computation_call
126
+ ),
127
+ computation_call,
128
+ ),
129
+ 0.044715,
130
+ ),
131
+ ),
132
+ 0.7978845608028654,
133
+ ),
134
+ ),
135
+ 1,
136
+ ),
137
+ )
138
+
139
+ def _hardswish_fusion(computation_call):
140
+ return CallFunction(
141
+ aten.div,
142
+ CallFunction(
143
+ aten.mul,
144
+ computation_call,
145
+ CallFunction(
146
+ aten.clamp_max,
147
+ CallFunction(
148
+ aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
149
+ ),
150
+ 6,
151
+ ),
152
+ ),
153
+ 6,
154
+ )
155
+
156
+ def _silu_fusion(computation_call):
157
+ return CallFunction(
158
+ aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call)
159
+ )
160
+
161
+ def _hardsigmoid_fusion(computation_call):
162
+ return CallFunction(
163
+ aten.div,
164
+ CallFunction(
165
+ aten.clamp_max,
166
+ CallFunction(
167
+ aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
168
+ ),
169
+ 6,
170
+ ),
171
+ 6,
172
+ )
173
+
174
+ def _leaky_relu_fusion(computation_call):
175
+ return CallFunction(
176
+ aten.where,
177
+ CallFunction(aten.gt, computation_call, 0),
178
+ computation_call,
179
+ CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")),
180
+ )
181
+
182
+ def _hardtanh_fusion(computation_call):
183
+ return CallFunction(
184
+ aten.clamp_max,
185
+ CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
186
+ KeywordArg("max_value"),
187
+ )
188
+
189
+ def _combined_fusion(computation_call, elementwise_op):
190
+ return CallFunction(elementwise_op, computation_call)
191
+
192
+ # binary_op(other, computation_op)
193
+ def _binary_fusion_v1(computation_call, binary_fn):
194
+ return CallFunction(binary_fn, KeywordArg("other"), computation_call)
195
+
196
+ # binary_op(computation_op, other)
197
+ def _binary_fusion_v2(computation_call, binary_fn):
198
+ return CallFunction(binary_fn, computation_call, KeywordArg("other"))
199
+
200
+ def _is_single_computation_op(computation_op, lowp_dtype=None):
201
+ def fn(match):
202
+ computation_nodes = filter_nodes(match.nodes, computation_op)
203
+
204
+ if lowp_dtype:
205
+ output_node_meta = match.output_node().meta.get("val")
206
+ if output_node_meta.dtype != lowp_dtype:
207
+ return False
208
+
209
+ if len(computation_nodes) < 1:
210
+ return False
211
+ if any(n.args[-3] != "none" for n in computation_nodes):
212
+ return False
213
+ return True
214
+
215
+ return fn
216
+
217
+ def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
218
+ def fn(match):
219
+ matched = _is_single_computation_op(computation_op, lowp_dtype)(match)
220
+ computation_node = filter_nodes(match.nodes, computation_op)[0]
221
+ if lowp_dtype:
222
+ conversion_dtype_nodes = filter_nodes(
223
+ match.nodes, prims.convert_element_type.default
224
+ )
225
+ if len(conversion_dtype_nodes) != 2:
226
+ return False
227
+ # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16
228
+ if computation_node == conversion_dtype_nodes[0].args[0]:
229
+ to_float = conversion_dtype_nodes[0].args[1]
230
+ to_lp = conversion_dtype_nodes[1].args[1]
231
+ else:
232
+ to_float = conversion_dtype_nodes[1].args[1]
233
+ to_lp = conversion_dtype_nodes[0].args[1]
234
+ matched = matched and to_float == torch.float and to_lp == lowp_dtype
235
+ return matched
236
+
237
+ return fn
238
+
239
+ def _register_unary_fusion_lowering(
240
+ pattern, unary_attr, computation_op, lowp_dtype=None
241
+ ):
242
+ @register_lowering_pattern(
243
+ pattern,
244
+ extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
245
+ )
246
+ def fn(match, *args, **kwargs):
247
+ computation_args = list(args)[:-3] + [
248
+ unary_attr.op_name,
249
+ unary_attr.scalars_attr,
250
+ unary_attr.algorithm_attr,
251
+ ]
252
+ return L[computation_op](*computation_args)
253
+
254
+ return fn
255
+
256
+ def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
257
+ @register_lowering_pattern(
258
+ pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
259
+ )
260
+ def fn(match, *args, **kwargs):
261
+ negative_slope = kwargs.get("negative_slope")
262
+ if isinstance(negative_slope, ir.TensorBox):
263
+ matched = False
264
+ else: # inp is a Number
265
+ matched = True
266
+ if lowp_dtype:
267
+ dtype1 = kwargs.get("to_float")
268
+ dtype2 = (
269
+ kwargs.get("to_bf16")
270
+ if lowp_dtype == torch.bfloat16
271
+ else kwargs.get("to_fp16")
272
+ )
273
+ matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
274
+ computation_args = list(args)
275
+ if matched:
276
+ computation_args = computation_args[:-3] + [
277
+ "leaky_relu",
278
+ [negative_slope],
279
+ "",
280
+ ]
281
+ return L[computation_op](*computation_args)
282
+ else:
283
+ # computation_args += ["none", [], ""]
284
+ out = L[computation_op](*computation_args)
285
+ if lowp_dtype:
286
+ out = L[prims.convert_element_type.default](out, dtype=torch.float)
287
+ out = L[aten.where](
288
+ L[aten.gt](out, 0),
289
+ out,
290
+ L[aten.mul](out, negative_slope),
291
+ )
292
+ if lowp_dtype:
293
+ out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
294
+ return out
295
+
296
+ return fn
297
+
298
+ def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
299
+ @register_lowering_pattern(
300
+ pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
301
+ )
302
+ def fn(match, *args, **kwargs):
303
+ min_value = kwargs.get("min_value")
304
+ max_value = kwargs.get("max_value")
305
+ if isinstance(min_value, ir.TensorBox) or isinstance(
306
+ max_value, ir.TensorBox
307
+ ):
308
+ matched = False
309
+ else: # inp is a Number
310
+ assert max_value is not None
311
+ matched = min_value <= max_value
312
+ if lowp_dtype:
313
+ dtype1 = kwargs.get("to_float")
314
+ dtype2 = (
315
+ kwargs.get("to_bf16")
316
+ if lowp_dtype == torch.bfloat16
317
+ else kwargs.get("to_fp16")
318
+ )
319
+ matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
320
+ computation_args = list(args)
321
+ if matched:
322
+ computation_args = computation_args[:-3] + [
323
+ "hardtanh",
324
+ [min_value, max_value],
325
+ "",
326
+ ]
327
+ return L[computation_op](*computation_args)
328
+ else:
329
+ out = L[computation_op](*computation_args)
330
+ if lowp_dtype:
331
+ out = L[prims.convert_element_type.default](out, dtype=torch.float)
332
+ out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
333
+ if lowp_dtype:
334
+ out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined]
335
+ return out
336
+
337
+ return fn
338
+
339
+ _binary_attr = {
340
+ aten.add: "add",
341
+ ops.add: "add",
342
+ aten.sub: "sub",
343
+ ops.sub: "sub",
344
+ }
345
+
346
+ def _is_valid_binary(match, fn):
347
+ binary_nodes = filter_nodes(match.nodes, fn)
348
+ if len(binary_nodes) < 1:
349
+ return False
350
+
351
+ def get_meta_value(argument: torch.fx.node.Argument):
352
+ # Only torch.fx.Node is expected to have meta.
353
+ if isinstance(argument, torch.fx.Node):
354
+ return argument.meta.get("val", None)
355
+ return None
356
+
357
+ if any(
358
+ not isinstance(get_meta_value(n.args[0]), torch.Tensor)
359
+ or not isinstance(get_meta_value(n.args[1]), torch.Tensor)
360
+ for n in binary_nodes
361
+ ):
362
+ return False
363
+ # check alpha is one.
364
+ if any(
365
+ get_arg_value(n, 2, kwarg_name="alpha") != 1.0
366
+ and get_arg_value(n, 2, kwarg_name="alpha") is not None
367
+ for n in binary_nodes
368
+ ):
369
+ return False
370
+ if any(
371
+ get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
372
+ or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
373
+ or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
374
+ for n in binary_nodes
375
+ ):
376
+ return False
377
+ # check args[0] and args[1] is not same
378
+ if any(n.args[0] == n.args[1] for n in binary_nodes):
379
+ return False
380
+ return True
381
+
382
+ def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
383
+ def fn(match):
384
+ if not _is_single_computation_op(computation_op)(match):
385
+ return False
386
+ if not _is_valid_binary(match, binary_op):
387
+ return False
388
+ return True
389
+
390
+ return fn
391
+
392
+ def _get_remaining_users(extra_input_node, compute_node):
393
+ # Think about this pattern:
394
+ # ReLU
395
+ # / \
396
+ # Conv1
397
+ # / \
398
+ # Conv2
399
+ # \ /
400
+ # Add
401
+ # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add.
402
+ # The Conv1 is the ancestor node of the current compute node (Conv2).
403
+ # This indicates that the buffer of ReLU has completed all its usage,
404
+ # So we can safely make changes to it now by doing Conv2->Add inplace fusion.
405
+ # Take above case as example:
406
+ # * extra_input_node: ReLU
407
+ # * compute_node: Conv2
408
+ # _get_remaining_users will return the users of extra_input_node which are not
409
+ # ancestor node of compute_node.
410
+ def _is_ancestor_node(_current_node, _ancestor_node):
411
+ # Check whether _ancestor_node is the ancestor node of _current_node
412
+ _node_list = [_current_node]
413
+ _visited_nodes = set()
414
+ while len(_node_list) != 0:
415
+ _current_node = _node_list.pop(0)
416
+ if _current_node not in _visited_nodes:
417
+ _visited_nodes.add(_current_node)
418
+ if _current_node == _ancestor_node:
419
+ return True
420
+ elif isinstance(
421
+ _current_node, torch.fx.Node
422
+ ) and _current_node.op not in ["placeholder", "output", "get_attr"]:
423
+ for input in _current_node.all_input_nodes:
424
+ _node_list.append(input) # noqa: PERF402
425
+ return False
426
+
427
+ return [
428
+ user
429
+ for user in list(extra_input_node.users)
430
+ if not _is_ancestor_node(compute_node, user)
431
+ ]
432
+
433
+ def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
434
+ def fn(match):
435
+ if not _is_valid_computation_binary(computation_op, binary_op)(match):
436
+ return False
437
+ binary_nodes = filter_nodes(match.nodes, binary_op)
438
+
439
+ def _get_compute_node(_binary_node, _other_index):
440
+ assert (
441
+ len(_binary_node.all_input_nodes) == 2
442
+ ), "Binary node should have 2 input nodes."
443
+ _compute_index = 1 if (_other_index == 0) else 0
444
+ return _binary_node.args[_compute_index]
445
+
446
+ def _other_input_not_inplaceable(_binary_node, _other_index):
447
+ _compute_node = _get_compute_node(_binary_node, _other_index)
448
+ return (
449
+ len(
450
+ _get_remaining_users(
451
+ _binary_node.args[_other_index], _compute_node
452
+ )
453
+ )
454
+ > 1
455
+ or _binary_node.args[_other_index] == _compute_node.args[0]
456
+ )
457
+
458
+ if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
459
+ return False
460
+ if any(
461
+ n.args[other_index].op in ["placeholder", "output"]
462
+ for n in binary_nodes
463
+ ):
464
+ return False
465
+ return True
466
+
467
+ return fn
468
+
469
+ def _register_binary_unary_fusion_lowering(
470
+ pattern,
471
+ computation_op,
472
+ binary_op,
473
+ fusion_op,
474
+ unary_attr=None,
475
+ ):
476
+ @register_lowering_pattern(
477
+ pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
478
+ )
479
+ def fn(match, *args, **kwargs):
480
+ other = kwargs.get("other")
481
+ assert isinstance(other, ir.TensorBox)
482
+ binary_attr = _binary_attr[binary_op]
483
+ args_list = list(args)
484
+ computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
485
+ if len(args_list) > 6:
486
+ if unary_attr is not None:
487
+ computation_args += [
488
+ 1.0,
489
+ unary_attr.op_name,
490
+ unary_attr.scalars_attr,
491
+ unary_attr.algorithm_attr,
492
+ ]
493
+ else:
494
+ computation_args += [1.0, None, [], None]
495
+ return L[fusion_op](*computation_args)
496
+
497
+ return fn
498
+
499
+ def _can_be_inplace(_other):
500
+ if isinstance(_other.data, ir.View):
501
+ return _can_be_inplace(_other.data)
502
+ else:
503
+ return not (
504
+ isinstance(_other.data, ir.ReinterpretView)
505
+ or len(_other.get_inputs_that_alias_output()) > 0
506
+ )
507
+
508
+ def _register_binary_unary_maybe_inplace_fusion_lowering(
509
+ pattern,
510
+ computation_op,
511
+ binary_op,
512
+ inplace_fusion_op,
513
+ outplace_fusion_op,
514
+ unary_attr=None,
515
+ other_index=None,
516
+ ):
517
+ @register_lowering_pattern(
518
+ pattern,
519
+ extra_check=_is_valid_computation_binary_inplace(
520
+ computation_op, binary_op, other_index
521
+ ),
522
+ )
523
+ def fn(match, *args, **kwargs):
524
+ other = kwargs.get("other")
525
+ assert isinstance(other, ir.TensorBox)
526
+ binary_attr = _binary_attr[binary_op]
527
+ args_list = list(args)
528
+ computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
529
+ if len(args_list) > 6:
530
+ if unary_attr is not None:
531
+ computation_args += [
532
+ 1.0,
533
+ unary_attr.op_name,
534
+ unary_attr.scalars_attr,
535
+ unary_attr.algorithm_attr,
536
+ ]
537
+ else:
538
+ computation_args += [1.0, None, [], None]
539
+ # Make sure the other is not an alias or mutation(fx side doesn't has such info).
540
+ other.realize()
541
+ if not _can_be_inplace(other):
542
+ return L[outplace_fusion_op](*computation_args)
543
+ return L[inplace_fusion_op](*computation_args)
544
+
545
+ return fn
546
+
547
+ computation_ops = [
548
+ mkldnn._convolution_pointwise.default,
549
+ mkldnn._linear_pointwise.default,
550
+ mkldnn._convolution_transpose_pointwise.default,
551
+ ]
552
+
553
+ class UnaryAttr:
554
+ def __init__(
555
+ self, op_name: str, scalars_attr=None, algorithm_attr=None
556
+ ) -> None:
557
+ self.op_name = op_name
558
+ self.scalars_attr = scalars_attr if scalars_attr else []
559
+ self.algorithm_attr = algorithm_attr if algorithm_attr else ""
560
+
561
+ def _register_unary_fusion():
562
+ computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]
563
+
564
+ def _unary_fusion_patterns(lowp_dtype):
565
+ replacement_unary_fusion_patterns = {
566
+ UnaryAttr("gelu", algorithm_attr="tanh"): [
567
+ _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype)
568
+ for call_fn in computation_call_fns
569
+ ],
570
+ UnaryAttr("gelu", algorithm_attr="none"): [
571
+ _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype)
572
+ for call_fn in computation_call_fns
573
+ ],
574
+ UnaryAttr("hardswish"): [
575
+ _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype)
576
+ for call_fn in computation_call_fns
577
+ ],
578
+ UnaryAttr("hardsigmoid"): [
579
+ _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype)
580
+ for call_fn in computation_call_fns
581
+ ],
582
+ UnaryAttr("swish"): [
583
+ _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype)
584
+ for call_fn in computation_call_fns
585
+ ],
586
+ }
587
+ if not lowp_dtype:
588
+ call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
589
+ replacement_unary_fusion_patterns.update(
590
+ {
591
+ UnaryAttr("relu"): [
592
+ _combined_fusion(u, aten.relu) for u in call_user1
593
+ ],
594
+ UnaryAttr("sigmoid"): [
595
+ _combined_fusion(u, aten.sigmoid) for u in call_user1
596
+ ],
597
+ UnaryAttr("tanh"): [
598
+ _combined_fusion(u, aten.tanh) for u in call_user1
599
+ ],
600
+ }
601
+ )
602
+
603
+ return replacement_unary_fusion_patterns
604
+
605
+ for lowp_dtype in [torch.bfloat16, torch.float16, None]:
606
+ replace_patterns = _unary_fusion_patterns(lowp_dtype)
607
+ for unary_attr, patterns in replace_patterns.items():
608
+ _register_unary_fusion_lowering(
609
+ patterns[0], unary_attr, computation_ops[0], lowp_dtype
610
+ )
611
+ _register_unary_fusion_lowering(
612
+ patterns[1], unary_attr, computation_ops[1], lowp_dtype
613
+ )
614
+ _register_unary_fusion_lowering(
615
+ patterns[2], unary_attr, computation_ops[2], lowp_dtype
616
+ )
617
+ _leaky_relu_patterns = [
618
+ _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
619
+ for call_fn in computation_call_fns
620
+ ]
621
+ for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
622
+ _register_leaky_relu_fusion_lowering(
623
+ pattern, computation_op, lowp_dtype
624
+ )
625
+ hardtanh_patterns = [
626
+ _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
627
+ for call_fn in computation_call_fns
628
+ ]
629
+ for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
630
+ _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
631
+
632
+ def _register_inplace_fusion():
633
+ binary_ops = [aten.add, ops.add]
634
+ inplace_fusion_op = mkldnn._convolution_pointwise_.binary
635
+ outplace_fusion_op = mkldnn._convolution_pointwise.binary
636
+ conv_call = _conv_call(users=1)
637
+ conv_op = computation_ops[0]
638
+ for binary_op in binary_ops:
639
+ binary_v1 = _binary_fusion_v1(conv_call, binary_op)
640
+ binary_unary_v1 = _combined_fusion(binary_v1, aten.relu)
641
+ _register_binary_unary_maybe_inplace_fusion_lowering(
642
+ binary_unary_v1,
643
+ conv_op,
644
+ binary_op,
645
+ inplace_fusion_op,
646
+ outplace_fusion_op,
647
+ other_index=0,
648
+ unary_attr=UnaryAttr("relu"),
649
+ )
650
+ _register_binary_unary_maybe_inplace_fusion_lowering(
651
+ binary_v1,
652
+ conv_op,
653
+ binary_op,
654
+ inplace_fusion_op,
655
+ outplace_fusion_op,
656
+ other_index=0,
657
+ )
658
+ binary_v2 = _binary_fusion_v2(conv_call, binary_op)
659
+ binary_unary_v2 = _combined_fusion(binary_v2, aten.relu)
660
+ _register_binary_unary_maybe_inplace_fusion_lowering(
661
+ binary_unary_v2,
662
+ conv_op,
663
+ binary_op,
664
+ inplace_fusion_op,
665
+ outplace_fusion_op,
666
+ other_index=1,
667
+ unary_attr=UnaryAttr("relu"),
668
+ )
669
+ _register_binary_unary_maybe_inplace_fusion_lowering(
670
+ binary_v2,
671
+ conv_op,
672
+ binary_op,
673
+ inplace_fusion_op,
674
+ outplace_fusion_op,
675
+ other_index=1,
676
+ )
677
+
678
+ def _register_binary_fusion():
679
+ binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
680
+ fusion_ops = [
681
+ mkldnn._convolution_pointwise.binary,
682
+ mkldnn._linear_pointwise.binary,
683
+ ]
684
+ _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
685
+ for computation_call, computation_op, fusion_op in zip(
686
+ _computation_user_1, computation_ops[:-1], fusion_ops
687
+ ):
688
+ for binary_op in binary_ops:
689
+ pattern = _binary_fusion_v2(computation_call, binary_op)
690
+ _register_binary_unary_fusion_lowering(
691
+ pattern, computation_op, binary_op, fusion_op
692
+ )
693
+
694
+ for binary_op in [aten.add, ops.add]:
695
+ pattern = _binary_fusion_v1(computation_call, binary_op)
696
+ _register_binary_unary_fusion_lowering(
697
+ pattern, computation_op, binary_op, fusion_op
698
+ )
699
+
700
+ def _register_binary_unary_fusion():
701
+ binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
702
+ fusion_ops = [mkldnn._convolution_pointwise.binary]
703
+ _computation_user_1 = [_conv_call(users=1)]
704
+ for computation_call, computation_op, fusion_op in zip(
705
+ _computation_user_1, computation_ops[:-1], fusion_ops
706
+ ):
707
+ for binary_op in binary_ops:
708
+ pattern_v1 = _combined_fusion(
709
+ _binary_fusion_v2(computation_call, binary_op), aten.relu
710
+ )
711
+ _register_binary_unary_fusion_lowering(
712
+ pattern_v1,
713
+ computation_op,
714
+ binary_op,
715
+ fusion_op,
716
+ unary_attr=UnaryAttr("relu"),
717
+ )
718
+ for binary_op in [aten.add, ops.add]:
719
+ pattern_v2 = _combined_fusion(
720
+ _binary_fusion_v1(computation_call, binary_op), aten.relu
721
+ )
722
+ _register_binary_unary_fusion_lowering(
723
+ pattern_v2,
724
+ computation_op,
725
+ binary_op,
726
+ fusion_op,
727
+ unary_attr=UnaryAttr("relu"),
728
+ )
729
+
730
+ def _recover_linear():
731
+ # convert reshape+linear+reshape to a single linear for applying fusion path.
732
+ @register_freezing_graph_pattern(
733
+ CallFunction(
734
+ aten.reshape.default,
735
+ CallFunction(
736
+ mkldnn._linear_pointwise.default,
737
+ CallFunction(
738
+ aten.reshape.default,
739
+ Arg(),
740
+ KeywordArg("reshape_1"),
741
+ _users=MULTIPLE,
742
+ ),
743
+ Arg(),
744
+ Arg(),
745
+ Arg(),
746
+ Arg(),
747
+ Arg(),
748
+ ),
749
+ KeywordArg("reshape_2"),
750
+ ),
751
+ pass_number=1,
752
+ )
753
+ def reshape_linear_reshape_pattern(match, *args, **kwargs):
754
+ def get_val(val):
755
+ return val if isinstance(val, int) else val.meta.get("val")
756
+
757
+ reshape_1 = kwargs.get("reshape_1")
758
+ reshape_2 = kwargs.get("reshape_2")
759
+ assert isinstance(reshape_1, list)
760
+ assert isinstance(reshape_2, list)
761
+ assert len(reshape_1) == 2
762
+
763
+ graph = match.graph
764
+ reshape_2_node = match.output_node()
765
+ linear_input_node = reshape_2_node.args[0].args[0].args[0]
766
+ # check linear's input's shape[:-1] == reshape_2[:-1]
767
+ # and check product(reshape_2[:-1]) == reshape_1[0]
768
+ can_remove_reshape = linear_input_node.meta.get("val").shape[
769
+ :-1
770
+ ] == torch.Size([get_val(val) for val in reshape_2[:-1]])
771
+ can_remove_reshape = can_remove_reshape and (
772
+ reduce(
773
+ operator.mul,
774
+ [get_val(val) for val in reshape_2[:-1]],
775
+ )
776
+ == get_val(reshape_1[0])
777
+ )
778
+
779
+ if can_remove_reshape:
780
+ repl = graph.call_function(mkldnn._linear_pointwise.default, args)
781
+ repl.meta.update(reshape_2_node.meta)
782
+ reshape_2_node.replace_all_uses_with(repl)
783
+ old_linear_node = reshape_2_node.args[0]
784
+ reshape_1_node = old_linear_node.args[0]
785
+ graph.erase_node(reshape_2_node)
786
+ graph.erase_node(old_linear_node)
787
+ if len(reshape_1_node.users) == 0:
788
+ graph.erase_node(reshape_1_node)
789
+
790
+ def is_linear_add_bias(match):
791
+ add_node = match.output_node()
792
+ linear_node = add_node.args[0]
793
+ packed_weight_node = linear_node.args[1]
794
+ assert packed_weight_node.target == mkldnn._reorder_linear_weight
795
+ transpose_weight_node = packed_weight_node.args[0]
796
+ assert transpose_weight_node.target == aten.permute.default
797
+ weight_meta = transpose_weight_node.args[0].meta.get("val")
798
+ bias_node = add_node.args[1]
799
+ if isinstance(bias_node, int):
800
+ # we only folding bias if it is a constant
801
+ return False
802
+ bias_meta = add_node.args[1].meta.get("val")
803
+ if weight_meta is None or bias_meta is None:
804
+ return False
805
+ assert weight_meta.dtype in (
806
+ torch.bfloat16,
807
+ torch.float16,
808
+ )
809
+ if bias_meta.dtype != weight_meta.dtype:
810
+ return False
811
+ return (
812
+ linear_node.args[2] is None
813
+ and bias_meta.dim() == 1
814
+ and bias_meta.size(0) == weight_meta.size(1)
815
+ )
816
+
817
+ # convert linear+bias to a single linear for applying fusion path.
818
+ @register_freezing_graph_pattern(
819
+ CallFunction(
820
+ aten.add.Tensor,
821
+ CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
822
+ Arg(),
823
+ ),
824
+ pass_number=1,
825
+ extra_check=is_linear_add_bias,
826
+ )
827
+ def linear_bias_pattern(match, *args):
828
+ graph = match.graph
829
+ add_node = match.output_node()
830
+ linear_node = add_node.args[0]
831
+ new_args = list(linear_node.args)
832
+ new_args[2] = add_node.args[1]
833
+ repl = graph.call_function(
834
+ mkldnn._linear_pointwise.default, tuple(new_args)
835
+ )
836
+ repl.meta.update(add_node.meta)
837
+ add_node.replace_all_uses_with(repl)
838
+ match.erase_nodes()
839
+
840
+ def _is_packable_mkldnn_rnn_layer(match):
841
+ lstm_node = match.output_node()
842
+ POS_WEIGHTS = [1, 2]
843
+ POS_INPUTS = [0, 5, 6]
844
+ POS_ARGS = POS_WEIGHTS + POS_INPUTS
845
+ # Weights should be Constant
846
+ if any(
847
+ lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS
848
+ ):
849
+ return False
850
+
851
+ # Meta info for weights and inputs should be available
852
+ if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):
853
+ return False
854
+
855
+ # Check device
856
+ if any(
857
+ lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"
858
+ for POS_ARG in POS_ARGS
859
+ ):
860
+ return False
861
+
862
+ # Check dtype
863
+ if any(
864
+ lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16
865
+ and not mkldnn._is_mkldnn_bf16_supported()
866
+ for POS_ARG in POS_ARGS
867
+ ):
868
+ return False
869
+ if any(
870
+ lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16
871
+ and not mkldnn._is_mkldnn_fp16_supported()
872
+ for POS_ARG in POS_ARGS
873
+ ):
874
+ return False
875
+
876
+ return True
877
+
878
+ def _is_packable_convolution(match):
879
+ """
880
+ Check if the node is supported for MKLDNN convolution.
881
+ """
882
+ conv_node = match.output_node()
883
+ input_meta_value = conv_node.args[0].meta.get("val")
884
+ weight_meta_value = conv_node.args[1].meta.get("val")
885
+ if input_meta_value is None or weight_meta_value is None:
886
+ return False
887
+ input_size = input_meta_value.shape
888
+ if conv_node.args[1].op != "get_attr":
889
+ return False
890
+ for meta_value in [input_meta_value, weight_meta_value]:
891
+ if (
892
+ meta_value is None
893
+ or meta_value.device.type != "cpu"
894
+ or (meta_value.dim() != 4 and meta_value.dim() != 5)
895
+ ):
896
+ return False
897
+ if (
898
+ input_meta_value.dtype == torch.bfloat16
899
+ or weight_meta_value.dtype == torch.bfloat16
900
+ ):
901
+ if not mkldnn._is_mkldnn_bf16_supported():
902
+ return False
903
+ if (
904
+ input_meta_value.dtype == torch.float16
905
+ or weight_meta_value.dtype == torch.float16
906
+ ):
907
+ if not mkldnn._is_mkldnn_fp16_supported():
908
+ return False
909
+ is_transposed = conv_node.args[-3]
910
+ if is_transposed:
911
+ # TODO: Support dynamic shape case for MKLDNN conv transpose.
912
+ if has_free_symbols(input_size):
913
+ return False
914
+ groups = conv_node.args[-1]
915
+ in_channels = weight_meta_value.size(0)
916
+ # doesn't support group_depthwise_conv_transpose.
917
+ if groups > 1 and groups == in_channels:
918
+ return False
919
+ # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
920
+ output_paddings = conv_node.args[-2]
921
+ strides = conv_node.args[3]
922
+ if any(
923
+ output_padding >= stride
924
+ for output_padding, stride in zip(output_paddings, strides)
925
+ ):
926
+ return False
927
+ return True
928
+
929
+ def _is_packable_linear(match):
930
+ """
931
+ Check if the node is supported for MKLDNN linear.
932
+ """
933
+ linear_node = match.output_node()
934
+ # mkldnn linear only supports beta=1or0 and alpha=1
935
+ if linear_node.target == aten.addmm.default:
936
+ alpha = linear_node.kwargs.get("alpha", 1.0)
937
+ beta = linear_node.kwargs.get("beta", 1.0)
938
+ if (beta != 0.0 and beta != 1.0) or alpha != 1.0:
939
+ return False
940
+ # weight_idx is 1 for aten.mm and is 2 for aten.addmm
941
+ weight_idx = 2 if linear_node.target == aten.addmm.default else 1
942
+ if linear_node.args[weight_idx].op != "get_attr":
943
+ return False
944
+ input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
945
+ weight_meta_value = linear_node.args[weight_idx].meta.get("val")
946
+ if input_meta_value is None or weight_meta_value is None:
947
+ return False
948
+ batch_size = input_meta_value.shape[0]
949
+ if (
950
+ input_meta_value.dtype == torch.float64
951
+ or weight_meta_value.dtype == torch.float64
952
+ ):
953
+ return False
954
+ is_lp_weight = weight_meta_value.dtype in (
955
+ torch.bfloat16,
956
+ torch.float16,
957
+ )
958
+ # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
959
+ # on aarch64, use mkldnn op for fp32 as well if acl is enabled
960
+ if (
961
+ not is_lp_weight
962
+ and not mkldnn._is_mkldnn_acl_supported()
963
+ and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
964
+ ):
965
+ return False
966
+ for meta_value in [input_meta_value, weight_meta_value]:
967
+ if (
968
+ meta_value is None
969
+ or meta_value.device.type != "cpu"
970
+ or meta_value.dim() != 2
971
+ ):
972
+ return False
973
+ if weight_idx == 2:
974
+ bias_meta_value = linear_node.args[0].meta.get("val")
975
+ if (
976
+ bias_meta_value is None
977
+ or meta_value.device.type != "cpu"
978
+ or bias_meta_value.dim() != 1
979
+ or bias_meta_value.size(0) != weight_meta_value.size(1)
980
+ ):
981
+ return False
982
+
983
+ if (
984
+ input_meta_value.dtype == torch.bfloat16
985
+ or weight_meta_value.dtype == torch.bfloat16
986
+ ):
987
+ if not mkldnn._is_mkldnn_bf16_supported():
988
+ return False
989
+ if (
990
+ input_meta_value.dtype == torch.float16
991
+ or weight_meta_value.dtype == torch.float16
992
+ ):
993
+ if not mkldnn._is_mkldnn_fp16_supported():
994
+ return False
995
+ return True
996
+
997
+ _aten_conv_args = (
998
+ Arg(),
999
+ Arg(),
1000
+ Arg(),
1001
+ Arg(),
1002
+ Arg(),
1003
+ Arg(),
1004
+ KeywordArg("is_transposed"),
1005
+ Arg(),
1006
+ Arg(),
1007
+ )
1008
+
1009
+ _aten_mkldnn_rnn_layer_args = (
1010
+ Arg(), # input
1011
+ Arg(), # weight0
1012
+ Arg(), # weight1
1013
+ Arg(), # weight2
1014
+ Arg(), # weight3
1015
+ Arg(), # hx_
1016
+ Arg(), # cx_
1017
+ KeywordArg("reverse"), # reverse
1018
+ Arg(), # batch_sizes
1019
+ Arg(), # mode
1020
+ Arg(), # hidden_size
1021
+ Arg(), # num_layers
1022
+ Arg(), # has_biases
1023
+ Arg(), # bidirectional
1024
+ Arg(), # batch_first
1025
+ Arg(), # train
1026
+ )
1027
+
1028
def _register_weight_pack_pass():
    """Register freezing-time graph patterns that replace aten convolution,
    mkldnn_rnn_layer and linear (mm/addmm) nodes with their oneDNN
    weight-prepacked counterparts."""

    @register_freezing_graph_pattern(
        CallFunction(aten.convolution.default, *_aten_conv_args),
        extra_check=_is_packable_convolution,
    )
    def convolution(match, *args, **kwargs):
        # Swap aten.convolution for mkldnn._convolution_pointwise (or the
        # transpose variant) fed by a prepacked weight node.
        is_transposed = kwargs.get("is_transposed")
        assert isinstance(is_transposed, bool)
        graph = match.graph
        conv_node = match.output_node()
        input_size = conv_node.args[0].meta.get("val").shape
        with graph.inserting_before(conv_node):
            # Constant conv args picked out of the matched node; presumably
            # stride/padding/dilation/groups — confirm against _aten_conv_args.
            constant_args = [args[4], args[3], args[5], args[-1]]
            packed_weight_op = mkldnn._reorder_convolution_weight
            packed_conv_op = mkldnn._convolution_pointwise.default
            if is_transposed:
                constant_args.insert(1, args[-2])  # output_padding
                packed_weight_op = mkldnn._reorder_convolution_transpose_weight
                packed_conv_op = mkldnn._convolution_transpose_pointwise.default
            if not has_free_symbols(input_size):
                # Static shape: prepack the weight at freeze time.
                packed_weight_inputs = (
                    (args[1],) + tuple(constant_args) + (input_size,)
                )
                packed_weight_node = graph.create_node(
                    "call_function", packed_weight_op, args=packed_weight_inputs
                )
            else:
                assert not is_transposed
                # For dynamic shape case, we need to pack weight in runtime.
                packed_weight_node = args[1]
            # Trailing ("none", [], "") selects "no post-op fusion" for the
            # pointwise op.
            packed_conv_inputs = (
                (args[0], packed_weight_node, args[2])
                + tuple(constant_args)
                + ("none", [], "")
            )
            packed_conv_node = graph.create_node(
                "call_function", packed_conv_op, tuple(packed_conv_inputs)
            )
            conv_node.replace_all_uses_with(packed_conv_node)
            packed_conv_node.meta.update(conv_node.meta)
            graph.erase_node(conv_node)

    @register_freezing_graph_pattern(
        CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
        extra_check=_is_packable_mkldnn_rnn_layer,
    )
    def mkldnn_rnn_layer(match, *args, **kwargs):
        # Re-emit mkldnn_rnn_layer with its two weight tensors replaced by the
        # outputs of a prepack node.
        def get_item(graph, node, index):
            return graph.call_function(operator.getitem, (node, index))

        graph = match.graph
        lstm_node = match.output_node()
        input = args[0]
        weight0, weight1 = args[1:3]
        reverse = kwargs.get("reverse")
        packed_lstm_op = aten.mkldnn_rnn_layer.default
        # Positional args of the matched node; indices follow
        # _aten_mkldnn_rnn_layer_args (hidden_size / has_biases / batch_first).
        hidden_size = args[9]
        has_biases = args[11]
        batch_first = args[13]
        with graph.inserting_before(lstm_node):
            packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
            packed_weight_inputs = (
                weight0,
                weight1,
                hidden_size,
                reverse,
                has_biases,
                batch_first,
            )
            packed_weight_node = graph.create_node(
                "call_function", packed_weight_op, packed_weight_inputs, {}, "name"
            )
            # The prepack op returns two tensors; unpack both via getitem.
            packed_weight_items = [
                get_item(graph, packed_weight_node, i) for i in range(2)
            ]
            pack_lstm_inputs = (
                args[0],
                *packed_weight_items,
                args[3],
                args[4],
                args[5],
                args[6],
                reverse,
                *args[7:],
            )

            packed_lstm_node = graph.create_node(
                "call_function", packed_lstm_op, args=pack_lstm_inputs
            )
            lstm_node.replace_all_uses_with(packed_lstm_node)
            packed_lstm_node.meta.update(lstm_node.meta)
            graph.erase_node(lstm_node)

    @register_freezing_graph_pattern(
        CallFunction(
            aten.addmm.default,
            Arg(),
            Arg(),
            Arg(),
            beta=KeywordArg("beta"),
            alpha=KeywordArg("alpha"),
        ),
        extra_check=_is_packable_linear,
    )
    @register_freezing_graph_pattern(
        CallFunction(aten.mm.default, Arg(), Arg()),
        extra_check=_is_packable_linear,
    )
    def linear(match, *args, **kwargs):
        # Handles both mm (no bias) and addmm (bias in args[0]) matches.
        graph = match.graph
        linear_node = match.output_node()
        input = args[0] if linear_node.target == aten.mm.default else args[1]
        # addmm with beta == 0 contributes no bias, same as mm.
        bias = (
            None
            if linear_node.target == aten.mm.default
            or (
                linear_node.target == aten.addmm.default
                and linear_node.kwargs.get("beta", 1.0) == 0.0
            )
            else args[0]
        )
        weight = args[1] if linear_node.target == aten.mm.default else args[2]
        with graph.inserting_before(linear_node):
            transpose_weight_node = graph.create_node(
                "call_function", aten.permute.default, (weight, (1, 0))
            )
            weight_dtype = weight.meta.get("val").dtype
            is_lp_weight = weight_dtype in (
                torch.bfloat16,
                torch.float16,
            )
            batch_size = input.meta.get("val").shape[0]
            if has_free_symbols(batch_size):
                assert (
                    is_lp_weight or mkldnn._is_mkldnn_acl_supported()
                ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
            # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
            packed_weight_inputs = (
                transpose_weight_node,
                batch_size.node.shape_env.size_hint(batch_size.node.expr)
                if has_free_symbols(batch_size)
                else batch_size,
            )
            # MKL packed matrix can't be copied to a different address because the internal implementation
            # depends on the alignment of internally-stored metadata.
            # In aot mode, we need to firstly save the packed weight, when loading it,
            # it will be in a different address which doesn't work.
            # Disable MKL prepack linear in AOT mode
            packed_weight_op = (
                mkldnn._reorder_linear_weight
                if (
                    is_lp_weight
                    or mkldnn._is_mkldnn_acl_supported()
                    or V.aot_compilation is True
                )
                else torch.ops.mkl._mkl_reorder_linear_weight
            )
            packed_weight_node = graph.create_node(
                "call_function", packed_weight_op, args=packed_weight_inputs
            )

            packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
            if (
                is_lp_weight
                or mkldnn._is_mkldnn_acl_supported()
                or V.aot_compilation is True
            ):
                # oneDNN path: bias plus "no post-op" attr triple.
                packed_linear_inputs += (bias, "none", [], "")
                packed_linear_op = mkldnn._linear_pointwise.default
            else:
                # MKL path keeps the transposed weight for runtime fallback.
                packed_linear_inputs += (transpose_weight_node, bias, batch_size)
                packed_linear_op = torch.ops.mkl._mkl_linear
            packed_linear_node = graph.create_node(
                "call_function", packed_linear_op, packed_linear_inputs
            )
            linear_node.replace_all_uses_with(packed_linear_node)
            packed_linear_node.meta.update(linear_node.meta)
            graph.erase_node(linear_node)
1206
+
1207
def _eliminate_duplicate_packed_nodes(gm):
    """
    Combine packed weight nodes with the same inputs to reduce memory usage.
    for example:
    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(32, 32, bias=True)

        def forward(self, x):
            return self.linear(self.linear(x))

    the above's packed weight nodes are duplicate if two linear calls have same input size.
    """
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm

    # Every op that produces a prepacked weight: two calls with identical args
    # produce identical packed weights, so duplicates can share one node.
    packed_weight_ops = [
        torch._C._nn.mkldnn_reorder_conv2d_weight,
        torch._C._nn.mkldnn_reorder_conv3d_weight,
        mkldnn._reorder_convolution_transpose_weight,
        mkldnn._reorder_linear_weight,
        mkldnn._reorder_mkldnn_rnn_layer_weight,
    ]
    if torch._C.has_mkl:
        packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)

    for node in gm.graph.nodes:
        if node.target in packed_weight_ops and len(node.args[0].users) > 1:
            # node.args[0] is the original weight; scan its other users for
            # pack calls identical to this one and fold them into `node`.
            for user_node in list(node.args[0].users.keys()):
                if (
                    user_node.target == node.target
                    and user_node != node
                    and user_node.args == node.args
                ):
                    user_node.replace_all_uses_with(node)
                    gm.graph.erase_node(user_node)
1244
+
1245
@functools.lru_cache(None)
def _mkldnn_fusion_init():
    """Register all mkldnn fusion lowerings exactly once per process
    (the ``lru_cache(None)`` decorator makes repeat calls no-ops)."""
    # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
    # Otherwise even the matmul or innerproduct can not be accelerated with acl
    if (
        torch.backends.mkldnn.enabled
        and torch.backends.mkldnn.is_available()
        and not torch.ops.mkldnn._is_mkldnn_acl_supported()
    ):
        _register_unary_fusion()
        _register_inplace_fusion()
        _register_binary_unary_fusion()
        _register_binary_fusion()
        _register_quantization_lowerings()
        _register_woq_lowerings()
1260
+
1261
@functools.lru_cache(None)
def _mkldnn_weight_pack_init():
    """Register the weight-prepack freezing passes once, when mkldnn is usable."""
    if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
        _register_weight_pack_pass()
        _recover_linear()
        _register_quantization_weight_pack_pass()
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import gc
3
+ import logging
4
+ import os
5
+ import random
6
+ import traceback
7
+
8
+ import numpy
9
+
10
+ import torch
11
+ import torch.optim as optim
12
+
13
+ from .. import config
14
+
15
+
16
+ logger: logging.Logger = logging.getLogger(__name__)
17
+
18
+ MAIN_RANDOM_SEED = 1337
19
+
20
+ # Set the CUBLAS_WORKSPACE_CONFIG environment variable
21
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
22
+
23
+
24
+ # If the two forward functions involve any non-deterministic operations,
25
+ # such as certain types of parallelism or asynchronous execution,
26
+ # this can also lead to different outputs.
27
def set_deterministic() -> None:
    """Seed every RNG source and force deterministic kernels.

    Seeds torch, the stdlib ``random`` module and numpy with
    ``MAIN_RANDOM_SEED`` so repeated forward runs see identical randomness,
    then enables torch's deterministic-algorithms mode.
    """
    for seed_fn in (torch.manual_seed, random.seed, numpy.random.seed):
        seed_fn(MAIN_RANDOM_SEED)
    torch.use_deterministic_algorithms(True)
34
+
35
+
36
def clean_memory() -> None:
    """Free unreferenced Python objects and cached CUDA blocks to avoid OOM."""
    gc.collect()  # drop Python-side garbage first so tensors become freeable
    torch.cuda.empty_cache()
40
+
41
+
42
+ # We compare the numerical results before and after pre/post grad fx passes
43
+ # transformation to make sure the numerical results are the same.
44
def compare_dict_tensors(dict_base, dict_control, precision):
    """Compare two name->tensor mappings with ``torch.allclose``.

    Returns True only when both dicts have the same keys and every pair of
    non-None tensors matches within ``precision`` (used as both rtol and atol,
    NaNs compare equal). Entries where either side is None (e.g. parameters
    without a populated ``.grad``) are skipped. Mismatches are logged.
    """
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            # Bug fix: previously fell through to dict_control[key] and raised
            # KeyError. A missing key is a mismatch, not a crash.
            is_allclose = False
            continue
        # Some parameters have `None`, and not every param has a valid .grad field, we skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
74
+
75
+
76
def compare_tuple_tensors(tuple_base, tuple_control, precision):
    """Element-wise ``torch.allclose`` comparison of two tensor sequences.

    Returns True when both sequences have the same length and every pair of
    non-None entries matches within ``precision`` (rtol and atol, NaNs equal).
    Pairs containing None on either side are skipped; mismatches are logged.
    """
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s, after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    is_allclose = True
    for base_out, control_out in zip(tuple_base, tuple_control):
        # Some parameters have `None`, we skip them
        if base_out is None or control_out is None:
            continue
        matches = torch.allclose(
            base_out,
            control_out,
            rtol=precision,
            atol=precision,
            equal_nan=True,
        )
        if not matches:
            logger.debug(
                "forward output before pre/post grad fx passes %s", base_out
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", control_out
            )
            is_allclose = False
    return is_allclose
104
+
105
+
106
def compare_parameters(model_base, model_control, precision):
    """Check that the named parameters of both models still match numerically."""
    base_params = dict(model_base.named_parameters())
    control_params = dict(model_control.named_parameters())
    return compare_dict_tensors(base_params, control_params, precision)
112
+
113
+
114
def compare_forward_output(pred_base, pred_control, precision):
    """Compare the forward outputs of the two models element by element."""
    return compare_tuple_tensors(pred_base, pred_control, precision)
120
+
121
+
122
def compare_gradients(model_base, model_control, precision):
    """Compare per-parameter gradients of the two models (None grads skipped)."""

    def _grads(model):
        return {name: param.grad for name, param in model.named_parameters()}

    return compare_dict_tensors(_grads(model_base), _grads(model_control), precision)
130
+
131
+
132
def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    """Run both models side by side for ``num_iterations`` iterations and log
    whether parameters, forward outputs and gradients stay numerically close.

    All results are only logged, never raised; gradient/optimizer steps are
    wrapped in try/except so a failure there does not abort the check.
    """
    clean_memory()
    for i in range(num_iterations):
        logger.info("start %s iteration", i)
        # Re-seed immediately before each forward so both models consume
        # identical randomness.
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result : %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result : %s", res)
        # tensor may not have a grad_fn
        try:
            # retain_graph=True because backward may run again next iteration
            # on the same graph.
            _ = pred_base[0].sum().backward(retain_graph=True)
            _ = pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result : %s", res)
        except Exception:
            logger.exception("Exception when comparing gradients")
            traceback.print_exc()

        # Optionally also take one SGD step on each model and re-compare
        # parameters; gated by inductor config.
        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result : %s",
                    res,
                )
            except Exception as e:
                logger.exception(
                    "Exception when optimizer is added to check parameter names"
                )
                traceback.print_exc()
        else:
            logger.warning(
                "no parameter with optimizer to compare with length %s before transformation"
                " and the length %s after transformation",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )
187
+
188
+
189
def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    """Run the before/after graph modules side by side and log mismatches.

    Any failure is logged and swallowed so the numeric check can never block
    the actual model run.
    """
    # need to topo-sort graphmodule before we run the model,
    # otherwise it may fail as refer before def
    # fail silently in order not to block the model run
    try:
        with torch.autograd.set_detect_anomaly(True):
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as err:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", err
        )
        traceback.print_exc()
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ import operator
5
+ import typing
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import torch
9
+ import torch._inductor.runtime.runtime_utils
10
+ from torch import Tensor
11
+ from torch._dynamo.utils import counters
12
+ from torch._inductor import utils
13
+ from torch._inductor.autoheuristic.autoheuristic import (
14
+ AHContext,
15
+ AutoHeuristic,
16
+ LocalFeedback,
17
+ )
18
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
19
+ context_add_strides,
20
+ context_add_using_tf32,
21
+ pad_mm_operations,
22
+ pad_mm_precondition,
23
+ )
24
+ from torch._subclasses.fake_tensor import FakeTensor
25
+ from torch.utils._mode_utils import no_dispatch
26
+
27
+ from ...utils._triton import has_triton
28
+ from ..pattern_matcher import (
29
+ fwd_only,
30
+ gen_register_replacement,
31
+ joint_fwd_bwd,
32
+ Match,
33
+ ReplaceFn,
34
+ SearchFn,
35
+ )
36
+
37
+
38
+ aten = torch.ops.aten
39
+
40
+
41
+ # This flag is only used for testing purpose.
42
+ # Changing it to True will ignore comparing do_bench times
43
+ # between original pattern and padded one.
44
+ _skip_do_bench_times = False
45
+
46
+
47
def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    """Return the FakeTensor stored in ``meta["val"]`` for each named kwarg
    of a pattern-matcher ``Match``, in the order given."""
    return [match.kwargs[name].meta["val"] for name in kwarg_names]
50
+
51
+
52
def unwrap_fake_args(*arg_names):
    """Decorator factory: the wrapped predicate receives the FakeTensors for
    ``arg_names`` (in order) instead of the raw Match object."""

    def decorator(func):
        def wrapper(match):
            return func(*fetch_fake_tensors(match, arg_names))

        return wrapper

    return decorator
61
+
62
+
63
def get_alignment_size(x: Tensor) -> int:
    """Alignment (in elements) that ``x``'s dtype wants for padded matmuls."""
    return get_alignment_size_dtype(x.dtype)
65
+
66
+
67
def get_alignment_size_dtype(dtype: torch.dtype) -> int:
    """Pad alignment per dtype: 8 for fp16/bf16, 4 for fp32, 0 (no padding)
    for everything else."""
    if dtype in (torch.float16, torch.half, torch.bfloat16):
        return 8
    if dtype in (torch.float32, torch.float):
        return 4
    return 0
74
+
75
+
76
def check_device(a: Tensor, b: Tensor) -> bool:
    """True only when both operands live on a CUDA device."""
    return all(t.is_cuda for t in (a, b))
78
+
79
+
80
def check_dtype(a: Tensor, b: Tensor) -> bool:
    """True only when both operands are floating-point tensors."""
    return all(t.is_floating_point() for t in (a, b))
82
+
83
+
84
def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    """Cheap gate for shape padding.

    Requires the ``shape_padding`` config, CUDA floating-point operands, and
    that every size/stride of each tensor is either a concrete int or a
    symbol with a hint — with at least one non-symbolic dimension per tensor.
    """
    # It's fine we have symbolic shapes or strides as long as they
    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
        if t is None:
            return True

        symbolic_cnt = 0
        for dim in t.size():
            if isinstance(dim, int):
                continue
            # non-int must be a hinted symbolic dim, anything else disqualifies
            if not (utils.is_symbolic(dim) and dim.node.has_hint()):
                return False
            symbolic_cnt += 1
        # filter out cases where all dimensions are symbolic
        if symbolic_cnt == len(t.size()):
            return False
        return all(
            isinstance(s, int) or (utils.is_symbolic(s) and s.node.has_hint())
            for s in t.stride()
        )

    if not torch._inductor.config.shape_padding:
        return False
    if not (check_device(mat1, mat2) and check_dtype(mat1, mat2)):
        return False
    return all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
117
+
118
+
119
def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    """Number of elements to append so ``x`` becomes a multiple of
    ``alignment_size``; 0 when padding is impossible or pointless."""
    # symbolic sizes are never padded, nor is anything with no alignment
    if isinstance(x, torch.SymInt) or alignment_size == 0:
        return 0
    # a length-1 dim can be squeezed away instead of padded
    if x == 1:
        return 0
    remainder = x % alignment_size
    return 0 if remainder == 0 else alignment_size - remainder
129
+
130
+
131
def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    """Append ``padded_length`` zeros along dimension ``dim`` (no-op when 0)."""
    if padded_length == 0:
        return x
    zeros = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat((x, zeros), dim=dim)
136
+
137
+
138
def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    """Traceable search pattern matched against aten.addmm for the pad-mm
    replacement registration."""
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
142
+
143
+
144
def should_pad_addmm(match: Match) -> bool:
    """Decide whether this addmm match is worth padding: cheap common checks
    first, then the benchmark-based heuristic."""
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    if not should_pad_common(mat1, mat2, input):
        return False
    return should_pad_bench(match, mat1, mat2, torch.ops.aten.addmm, input=input)
149
+
150
+
151
def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
):
    """Run addmm with operands zero-padded to aligned sizes, then slice the
    result back down to the original (m, n) shape.

    ``*_pre_padded`` flags mean the corresponding matrix already carries its
    padding and must not be padded again.
    """
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )

    # the add broadcasts, so the bias is only padded along real (non-1) dims
    if input is not None:
        if n_padded_length:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        if m_padded_length and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    padded = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)

    # strip the padded rows/columns off the result
    if m_padded_length:
        padded = padded[:-m_padded_length, :]
    if n_padded_length:
        padded = padded[:, :-n_padded_length]
    return padded
191
+
192
+
193
def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    """Replacement body for the addmm pattern: compute the m/k/n pad amounts
    from each matrix's dtype alignment and defer to pad_addmm."""
    align1 = get_alignment_size(mat1)
    align2 = get_alignment_size(mat2)
    return pad_addmm(
        input,
        mat1,
        mat2,
        get_padded_length(mat1.shape[0], align1),  # m
        get_padded_length(mat1.shape[1], align1),  # k
        get_padded_length(mat2.shape[1], align2),  # n
        beta,
        alpha,
    )
209
+
210
+
211
def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    """Heuristic: is an (M, K) @ (K, N) matmul compute bound on this device?

    Compares the matmul's arithmetic intensity against half the machine
    balance (TFLOPS / DRAM bandwidth). Returns False for degenerate shapes.
    """
    total_io = M * K + N * K + M * N
    if total_io == 0:
        return False
    arithmetic_intensity = (M * N * K) / total_io

    # we have experienced some large perf hits in this case, even in bandwidth
    # bound regimes -- doesn't repro on H100s
    if (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and torch.cuda.get_device_capability() < (9, 0)
    ):
        return True

    # Device-metric queries fail with AMD; treat that as compute bound.
    try:
        machine_balance = (
            1000 * utils.get_device_tflops(dtype)
        ) / utils.get_gpu_dram_gbps()
    except Exception:
        return True

    # dram_gbps might be underestimating bandwidth because of cache.
    # If we estimate machine balance too low we might miss some speedups;
    # if we estimate too high there will be unnecessary compilation time increase.
    # TODO - finetune coefficient here. As a reference point, Triton mm model assumes
    # 80% of reads are in cache and cache is 4x faster than dram_gbps
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance
242
+
243
+
244
@functools.lru_cache(None)
def get_pad_cache():
    """Lazily build the process-wide LocalCache used to memoize pad decisions
    and baseline benchmark times (lru_cache makes it a singleton)."""
    return torch._inductor.codecache.LocalCache()
247
+
248
+
249
def get_cached_should_pad(key: str) -> bool:
    # Memoized pad/no-pad verdict for this key; None on a cache miss.
    return get_pad_cache().lookup(key)


def set_cached_should_pad(key: str, value: bool):
    # Persist the pad/no-pad verdict for this key.
    return get_pad_cache().set_value(key, value=value)


def get_cached_base_mm_benchmark_time(key: str) -> float:
    # Benchmark time of the unpadded matmul; None if not yet measured.
    return get_pad_cache().lookup(key)


def set_cached_base_mm_benchmark_time(key: str, value: float):
    # Persist the unpadded-matmul benchmark time.
    return get_pad_cache().set_value(key, value=value)
263
+
264
+
265
def should_pad_bench_key(
    match,
    mat1: Tensor,
    mat2: Tensor,
    op,
    input: Optional[Tensor] = None,
    is_base_time_key=False,
) -> str:
    """Build the string cache key for a pad-benchmark decision.

    The key folds in each tensor's shape/stride/dtype, the op, whether each
    mat's padding cost can be excluded (memory-planned away), and the TF32
    setting for fp32 inputs. ``is_base_time_key`` selects the key used for
    the unpadded baseline time, which must not depend on padding exclusion.
    """

    def tensor_key(t):
        return (t.shape, t.stride(), t.dtype)

    def fmt_pad(name):
        if is_base_time_key:
            return None
        return f"exclude_pad:{should_exclude_padding_time(match, name)}"

    tf32_key = (
        None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
    )
    key = str(
        (
            tensor_key(mat1),
            tensor_key(mat2),
            fmt_pad("mat1"),
            fmt_pad("mat2"),
            op,
            input if input is None else tensor_key(input),
            tf32_key,
        )
    )
    return f"base mm time: {key}" if is_base_time_key else key
299
+
300
+
301
def get_non_view_def(node):
    """Walk backwards through getitem and view-like ops to the node that
    actually produces the data."""
    if node.op == operator.getitem:
        return get_non_view_def(node.args[0])

    is_view_call = (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and utils.is_view(node.target)
    )
    if is_view_call:
        return get_non_view_def(node.all_input_nodes[0])

    return node
313
+
314
+
315
def should_exclude_padding_time(match, arg_name):
    """Whether the cost of padding ``arg_name`` can be ignored while
    benchmarking, because the zero-fill could be memory-planned away."""
    node_def = get_non_view_def(match.kwargs[arg_name])

    # constant padding converts tensors to contiguous so even if the input tensor
    # can be planned layout transform is not free. TODO - way to pad and preserve layout ?
    if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
        return False

    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
    # We would only able to completely plan these out if we were only doing
    # first dimension padding. non-first we would still need a copy
    # because these outputs are fixed dense.
    cannot_plan_output = [
        aten.mm.default,
        aten.convolution.default,
        aten.convolution_backward.default,
        aten.bmm.default,
        aten.addmm.default,
        aten._scaled_dot_product_flash_attention.default,
        aten._scaled_dot_product_efficient_attention.default,
    ]
    if node_def.target in cannot_plan_output:
        return False

    # a wide cat won't be turned into a pointwise op, so its output can't be planned
    if node_def.target == aten.cat.default and (
        len(node_def.all_input_nodes)
        > torch._inductor.config.max_pointwise_cat_inputs
    ):
        return False

    # optimistically assume we should be able to memory plan away
    # all non inputs
    return node_def.op != "placeholder"
350
+
351
+
352
def should_pad(key: str, ori_time, pad_time) -> bool:
    """Decide pad vs no-pad from benchmark times and memoize the verdict.

    Shape padding introduces additional memory ops. Based on microbenchmarks,
    1.1x represents a reasonable tradeoff between performance improvement
    from shape padding and overhead from additional memory ops; the factor
    can be overridden via post_grad_fusion_options.
    TODO: Build a learned model which would be better than this heuristic.
    """
    options = torch._inductor.config.post_grad_fusion_options
    multiplier = 1.1
    if "shape_padding_multiplier" in options:
        multiplier = options["shape_padding_multiplier"].get("value", 1.1)
        counters["inductor"]["shape_padding_multiplier"] += 1
    verdict = _skip_do_bench_times or ori_time > pad_time * multiplier
    set_cached_should_pad(key, verdict)
    return verdict
365
+
366
+
367
def should_pad_bench(
    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    """Benchmark-driven decision: is padding this mm/addmm/bmm worth it?

    Computes the m/k/n pad amounts, applies cheap early-outs (nothing to pad,
    zero-sized dims, forced padding, no triton, not compute bound), consults
    the on-disk cache, then benchmarks the original op against its padded
    variant on realized (non-fake) tensors and caches the verdict.
    """
    do_bench = functools.partial(
        torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
        warmup=5,
    )
    m_padded_length = 0
    n_padded_length = 0
    batchsize = 1
    # no_dispatch so tensor ops below bypass fake-tensor/dispatch modes
    with no_dispatch():
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m = mat1.shape[0]
            k = mat1.shape[1]
            n = mat2.shape[1]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
        elif op is torch.ops.aten.bmm:
            batchsize = mat1.shape[0]
            m = mat1.shape[1]
            k = mat1.shape[2]
            n = mat2.shape[2]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        else:
            return False

        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        def realize_symbols(ds):
            # replace symbolic dims with their integer hints
            return [d if isinstance(d, int) else d.node.hint for d in ds]

        if any(
            dim == 0
            for dim in itertools.chain(
                realize_symbols(mat1.shape), realize_symbols(mat2.shape)
            )
        ):
            return False

        if torch._inductor.config.force_shape_pad:
            return True

        if not has_triton():
            return False

        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially false
        # since it does file io
        key = should_pad_bench_key(match, mat1, mat2, op, input)

        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_tensor(t):
            # Materialize a real random tensor matching a FakeTensor's
            # size/stride hints so it can actually be benchmarked.
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            else:
                return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)

        # since we key on whether or not the inputs can be memory planned, set cache for the
        # original time which is unaffected by whether or not the input can be planned
        ori_time_key = should_pad_bench_key(
            match, mat1, mat2, op, input, is_base_time_key=True
        )
        ori_time = get_cached_base_mm_benchmark_time(ori_time_key)
        if ori_time is None and op is torch.ops.aten.addmm and input is not None:
            # realize bias for addmm
            input = realize_tensor(input)

        mat1_pad = mat1
        mat2_pad = mat2

        is_bmm = op is torch.ops.aten.bmm

        # When padding can be memory-planned away, pre-pad the matrix once
        # outside the timed region and only time the zero-fill of the pad.
        mat1_pre_padded = should_exclude_padding_time(match, "mat1")
        fns = []
        if mat1_pre_padded and (m_padded_length or k_padded_length):
            mat1_pad = pad_mat1(
                mat1_pad,
                m_padded_length=m_padded_length,
                k_padded_length=k_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                # NOTE(review): when one padded length is 0, the `-0:` slice
                # selects the whole dimension, zero-filling more than the pad
                # region — confirm this is intended.
                if is_bmm:
                    mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0)
                else:
                    mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0)

            fns.append(write_pad)

        mat2_pre_padded = should_exclude_padding_time(match, "mat2")
        if mat2_pre_padded and (k_padded_length or n_padded_length):
            mat2_pad = pad_mat2(
                mat2_pad,
                k_padded_length=k_padded_length,
                n_padded_length=n_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0)
                else:
                    mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0)

            fns.append(write_pad)

        # The final fn runs the padded op itself.
        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            fns.append(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        elif op is torch.ops.aten.mm:
            fns.append(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        else:
            fns.append(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )

        def orig_bench_fn():
            if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
                op(mat1, mat2)
            else:
                op(input, mat1, mat2)

        def pad_bench_fn():
            for fn in fns:
                fn()

        # Optionally let the learned autoheuristic decide (mm only).
        if (
            torch._inductor.config.run_autoheuristic("pad_mm")
            and op is torch.ops.aten.mm
        ):
            ah_should_pad = run_autoheuristic(
                mat1,
                mat2,
                orig_bench_fn,
                pad_bench_fn,
                m_padded_length,
                k_padded_length,
                n_padded_length,
                do_bench,
                mat1_pre_padded,
                mat2_pre_padded,
                ori_time,
                ori_time_key,
                key,
            )
            if ah_should_pad is not None:
                return ah_should_pad

        if ori_time is None:
            ori_time = do_bench(orig_bench_fn)
            set_cached_base_mm_benchmark_time(ori_time_key, ori_time)

        pad_time = do_bench(pad_bench_fn)
        return should_pad(key, ori_time, pad_time)
571
+
572
+
573
+ def get_context(
574
+ mat1: Tensor,
575
+ mat2: Tensor,
576
+ mat1_pre_padded: bool,
577
+ mat2_pre_padded: bool,
578
+ m_padded_length: int,
579
+ k_padded_length: int,
580
+ n_padded_length: int,
581
+ ):
582
+ context = AHContext()
583
+
584
+ context.add_feature("m", mat1.shape[0])
585
+ context.add_feature("k", mat1.shape[1])
586
+ context.add_feature("n", mat2.shape[1])
587
+
588
+ context_add_strides(context, "mat1", mat1.stride())
589
+ context_add_strides(context, "mat2", mat2.stride())
590
+
591
+ context.add_feature("m_padded_length", m_padded_length)
592
+ context.add_feature("k_padded_length", k_padded_length)
593
+ context.add_feature("n_padded_length", n_padded_length)
594
+
595
+ context.add_feature("mat1_align_size", get_alignment_size(mat1))
596
+ context.add_feature("mat2_align_size", get_alignment_size(mat2))
597
+
598
+ context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True)
599
+ context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True)
600
+
601
+ context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True)
602
+ context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True)
603
+
604
+ context_add_using_tf32(context, mat1.dtype)
605
+ return context
606
+
607
+
608
+ def run_autoheuristic(
609
+ mat1: Tensor,
610
+ mat2: Tensor,
611
+ orig_bench_fn: Callable[[], None],
612
+ pad_bench_fn: Callable[[], None],
613
+ m_padded_length: int,
614
+ k_padded_length: int,
615
+ n_padded_length: int,
616
+ do_bench,
617
+ mat1_pre_padded: bool,
618
+ mat2_pre_padded: bool,
619
+ ori_time,
620
+ ori_time_key: str,
621
+ key: str,
622
+ ) -> Optional[bool]:
623
+ def feedback_fn(choice: str):
624
+ if choice == orig_choice:
625
+ return do_bench(orig_bench_fn)
626
+ elif choice == pad_choice:
627
+ return do_bench(pad_bench_fn)
628
+ return None
629
+
630
+ def fallback() -> str:
631
+ return "autotune"
632
+
633
+ orig_choice = "orig"
634
+ pad_choice = "pad"
635
+ choices = [orig_choice, pad_choice]
636
+ feedback = LocalFeedback(feedback_fn)
637
+ context = get_context(
638
+ mat1,
639
+ mat2,
640
+ mat1_pre_padded,
641
+ mat2_pre_padded,
642
+ m_padded_length,
643
+ k_padded_length,
644
+ n_padded_length,
645
+ )
646
+ name = "pad_mm"
647
+ autoheuristic = AutoHeuristic(
648
+ fallback=fallback,
649
+ choices=choices,
650
+ feedback=feedback,
651
+ context=context,
652
+ name=name,
653
+ augment_context=pad_mm_operations(),
654
+ precondition=pad_mm_precondition,
655
+ )
656
+ choice = autoheuristic.get_choice()
657
+ choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
658
+ ah_should_pad = choice2should_pad.get(choice, None)
659
+
660
+ if torch._inductor.config.collect_autoheuristic(name):
661
+ ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
662
+ ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)
663
+
664
+ # if precondition is not satisifed, autoheuristic does not collect data
665
+ if ah_ori_time is not None and ah_pad_time is not None:
666
+ if ori_time is None:
667
+ set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time)
668
+ return should_pad(key, ah_ori_time, ah_pad_time)
669
+ if ah_should_pad is not None:
670
+ set_cached_should_pad(key, ah_should_pad)
671
+ return ah_should_pad
672
+
673
+
674
+ def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
675
+ return aten.mm(mat1, mat2)
676
+
677
+
678
+ def should_pad_mm(match: Match) -> bool:
679
+ mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
680
+ return should_pad_common(mat1, mat2) and should_pad_bench(
681
+ match, mat1, mat2, torch.ops.aten.mm
682
+ )
683
+
684
+
685
+ def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
686
+ if m_padded_length == 0 and k_padded_length == 0:
687
+ return mat1
688
+ elif k_padded_length != 0 and m_padded_length != 0:
689
+ # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
690
+ pad_arg = [0, k_padded_length, 0, m_padded_length]
691
+ if is_bmm:
692
+ pad_arg.extend((0, 0))
693
+ return aten.constant_pad_nd(mat1, pad_arg)
694
+ elif m_padded_length != 0:
695
+ return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
696
+ else:
697
+ assert k_padded_length != 0
698
+ return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)
699
+
700
+
701
+ def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
702
+ if k_padded_length == 0 and n_padded_length == 0:
703
+ return mat2
704
+ elif k_padded_length != 0 and n_padded_length != 0:
705
+ # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
706
+ pad_arg = [0, n_padded_length, 0, k_padded_length]
707
+ if is_bmm:
708
+ pad_arg.extend((0, 0))
709
+ return aten.constant_pad_nd(mat2, pad_arg)
710
+ elif k_padded_length != 0:
711
+ return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
712
+ else:
713
+ assert n_padded_length != 0
714
+ return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)
715
+
716
+
717
+ def pad_mm(
718
+ mat1: Tensor,
719
+ mat2: Tensor,
720
+ m_padded_length: int,
721
+ k_padded_length: int,
722
+ n_padded_length: int,
723
+ mat1_pre_padded: bool = False,
724
+ mat2_pre_padded: bool = False,
725
+ ) -> Tensor:
726
+ if not mat1_pre_padded:
727
+ mat1 = pad_mat1(
728
+ mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
729
+ )
730
+ if not mat2_pre_padded:
731
+ mat2 = pad_mat2(
732
+ mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
733
+ )
734
+ res = aten.mm(mat1, mat2)
735
+ if m_padded_length != 0:
736
+ res = res[:-m_padded_length, :]
737
+ if n_padded_length != 0:
738
+ res = res[:, :-n_padded_length]
739
+ return res
740
+
741
+
742
+ def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
743
+ k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
744
+ m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
745
+ n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
746
+ return pad_mm(
747
+ mat1,
748
+ mat2,
749
+ m_padded_length,
750
+ k_padded_length,
751
+ n_padded_length,
752
+ )
753
+
754
+
755
+ def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
756
+ return aten.bmm(mat1, mat2)
757
+
758
+
759
+ def should_pad_bmm(match: Match) -> bool:
760
+ mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
761
+ return should_pad_common(mat1, mat2) and should_pad_bench(
762
+ match, mat1, mat2, torch.ops.aten.bmm
763
+ )
764
+
765
+
766
+ def pad_bmm(
767
+ mat1: Tensor,
768
+ mat2: Tensor,
769
+ m_padded_length: int,
770
+ k_padded_length: int,
771
+ n_padded_length: int,
772
+ mat1_pre_padded: bool = False,
773
+ mat2_pre_padded: bool = False,
774
+ ) -> Tensor:
775
+ if not mat1_pre_padded:
776
+ mat1 = pad_mat1(
777
+ mat1,
778
+ m_padded_length=m_padded_length,
779
+ k_padded_length=k_padded_length,
780
+ is_bmm=True,
781
+ )
782
+ if not mat2_pre_padded:
783
+ mat2 = pad_mat2(
784
+ mat2,
785
+ k_padded_length=k_padded_length,
786
+ n_padded_length=n_padded_length,
787
+ is_bmm=True,
788
+ )
789
+ res = aten.bmm(mat1, mat2)
790
+ if m_padded_length != 0:
791
+ res = res[:, :-m_padded_length, :]
792
+ if n_padded_length != 0:
793
+ res = res[:, :, :-n_padded_length]
794
+ return res
795
+
796
+
797
+ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
798
+ k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
799
+ n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
800
+ m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
801
+ return pad_bmm(
802
+ mat1,
803
+ mat2,
804
+ m_padded_length,
805
+ k_padded_length,
806
+ n_padded_length,
807
+ )
808
+
809
+
810
+ @functools.lru_cache(None)
811
+ def _pad_mm_init():
812
+ from .joint_graph import patterns
813
+
814
+ if torch.cuda.is_available():
815
+ # workaround https://github.com/pytorch/pytorch/issues/97894
816
+ device = "cuda"
817
+ else:
818
+ device = "cpu"
819
+
820
+ # sizes/values dont actually matter for initial trace
821
+ # once we get a possible match we re-trace with the actual values and verify the match still holds
822
+
823
+ dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
824
+ dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
825
+
826
+ dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
827
+ dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
828
+
829
+ dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)
830
+
831
+ # workaround https://github.com/pytorch/pytorch/issues/97894
832
+ # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
833
+ rep = {"beta": 0.213377, "alpha": 0.113377}
834
+
835
+ for pattern, replacement, args, workaround, extra_check in [
836
+ (
837
+ typing.cast(SearchFn, mm_pattern),
838
+ typing.cast(ReplaceFn, mm_replace),
839
+ [dim2a(), dim2b()],
840
+ {},
841
+ should_pad_mm,
842
+ ),
843
+ (
844
+ typing.cast(SearchFn, bmm_pattern),
845
+ typing.cast(ReplaceFn, bmm_replace),
846
+ [dim3a(), dim3b()],
847
+ {},
848
+ should_pad_bmm,
849
+ ),
850
+ (
851
+ typing.cast(SearchFn, addmm_pattern),
852
+ typing.cast(ReplaceFn, addmm_replace),
853
+ [dim1a(), dim2a(), dim2b()],
854
+ rep,
855
+ should_pad_addmm,
856
+ ),
857
+ ]:
858
+ assert isinstance(workaround, dict) # mypy is unable to infer the type properly
859
+ name = pattern.__name__
860
+
861
+ gen_register_replacement(
862
+ f"{name}_training",
863
+ pattern,
864
+ replacement,
865
+ args,
866
+ joint_fwd_bwd,
867
+ patterns,
868
+ extra_check=extra_check,
869
+ scalar_workaround=workaround,
870
+ )
871
+
872
+ gen_register_replacement(
873
+ f"{name}_inference",
874
+ pattern,
875
+ replacement,
876
+ args,
877
+ fwd_only,
878
+ patterns,
879
+ extra_check=extra_check,
880
+ scalar_workaround=workaround,
881
+ )
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py ADDED
@@ -0,0 +1,1318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import functools
4
+ import itertools
5
+ import logging
6
+ import operator
7
+ from collections import Counter, defaultdict
8
+ from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
9
+
10
+ import torch
11
+ import torch._inductor as inductor
12
+ import torch.utils._pytree as pytree
13
+ from torch import fx
14
+ from torch._decomp import register_decomposition
15
+ from torch._dynamo.utils import counters, optimus_scuba_log
16
+ from torch._inductor import comms
17
+ from torch._inductor.virtualized import ops
18
+ from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
19
+ from torch._utils_internal import upload_graph
20
+ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
21
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
22
+
23
+ from .. import config, ir, pattern_matcher
24
+ from ..codegen.common import BackendFeature, has_backend_feature
25
+ from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
26
+ from ..lowering import lowerings as L
27
+ from ..pattern_matcher import (
28
+ _return_true,
29
+ Arg,
30
+ CallFunction,
31
+ CallFunctionVarArgs,
32
+ filter_nodes,
33
+ get_arg_value,
34
+ get_mutation_region_id,
35
+ Ignored,
36
+ init_once_fakemode,
37
+ KeywordArg,
38
+ ListOf,
39
+ Match,
40
+ MULTIPLE,
41
+ PatternMatcherPass,
42
+ register_graph_pattern,
43
+ stable_topological_sort,
44
+ )
45
+ from ..utils import decode_device, get_gpu_type, is_pointwise_use
46
+ from ..virtualized import V
47
+ from .b2b_gemm import B2B_GEMM_PASS
48
+ from .ddp_fusion import fuse_ddp_communication
49
+ from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
50
+ from .micro_pipeline_tp import micro_pipeline_tp_pass
51
+ from .pre_grad import is_same_dict, save_inductor_dict
52
+ from .reinplace import reinplace_inplaceable_ops
53
+ from .split_cat import POST_GRAD_PATTERNS
54
+
55
+
56
+ if TYPE_CHECKING:
57
+ from sympy import Expr
58
+
59
+
60
+ log = logging.getLogger(__name__)
61
+ aten = torch.ops.aten
62
+ prims = torch.ops.prims
63
+
64
+ # First pass_patterns[0] are applied, then [1], then [2]
65
+ pass_patterns = [
66
+ PatternMatcherPass(),
67
+ PatternMatcherPass(),
68
+ PatternMatcherPass(),
69
+ ]
70
+
71
+
72
+ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
73
+ """
74
+ Passes that run on after grad. This is called once on the forwards
75
+ graph and once on the backwards graph.
76
+
77
+ The IR here has been normalized and functionalized.
78
+ """
79
+ if config.dce:
80
+ # has some issues with mutation in inference mode
81
+ gm.graph.eliminate_dead_code()
82
+
83
+ if is_inference and config.reorder_for_locality:
84
+ reorder_for_locality(gm.graph)
85
+
86
+ fake_tensor_updater = FakeTensorUpdater(gm.graph)
87
+
88
+ if config.post_grad_custom_pre_pass is not None:
89
+ with GraphTransformObserver(
90
+ gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
91
+ ):
92
+ config.post_grad_custom_pre_pass(gm.graph)
93
+
94
+ if config.pattern_matcher:
95
+ lazy_init()
96
+ optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
97
+ group_batch_fusion_passes(gm.graph, pre_grad=False)
98
+ remove_noop_ops(gm.graph)
99
+ for patterns in pass_patterns:
100
+ patterns.apply(gm.graph) # type: ignore[arg-type]
101
+ for pass_name in config.post_grad_fusion_options:
102
+ # skip all patterns for group batch fusions
103
+ if pass_name in POST_GRAD_FUSIONS:
104
+ continue
105
+ pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
106
+ inductor_before_change = save_inductor_dict(
107
+ [pattern_matcher_pass.pass_name]
108
+ )
109
+ pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
110
+ if not is_same_dict(counters["inductor"], inductor_before_change):
111
+ optimus_scuba_log[
112
+ f"{pattern_matcher_pass.pass_name}_post_grad"
113
+ ] = upload_graph(gm.graph)
114
+ if config.b2b_gemm_pass:
115
+ B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]
116
+
117
+ if config._micro_pipeline_tp:
118
+ micro_pipeline_tp_pass(gm.graph)
119
+
120
+ if config._fuse_ddp_communication:
121
+ fuse_ddp_communication(
122
+ gm.graph,
123
+ config._fuse_ddp_communication_passes,
124
+ config._fuse_ddp_bucket_size,
125
+ )
126
+
127
+ if config.post_grad_custom_post_pass is not None:
128
+ with GraphTransformObserver(
129
+ gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
130
+ ):
131
+ config.post_grad_custom_post_pass(gm.graph)
132
+
133
+ stable_topological_sort(gm.graph)
134
+
135
+ move_constructors_to_gpu(gm.graph)
136
+
137
+ fake_tensor_updater.incremental_update()
138
+
139
+ # Keep these last, since they introduces mutation. Look at
140
+ # ./fx_passes/README.md for a discussion of mutation invariants.
141
+ reinplace_inplaceable_ops(gm.graph)
142
+ decompose_auto_functionalized(gm.graph)
143
+
144
+ comms.reinplace_fsdp_all_gather(gm.graph)
145
+
146
+ gm.recompile()
147
+ optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
148
+ gm.graph.lint()
149
+
150
+
151
+ @init_once_fakemode
152
+ def lazy_init():
153
+ if torch._C._has_mkldnn:
154
+ from . import decompose_mem_bound_mm # noqa: F401
155
+ from .mkldnn_fusion import _mkldnn_fusion_init
156
+
157
+ _mkldnn_fusion_init()
158
+
159
+
160
+ def reorder_for_locality(graph: torch.fx.Graph):
161
+ def visit(other_node):
162
+ if (
163
+ other_node.op == "call_function"
164
+ and other_node.target != operator.getitem
165
+ and all((n in seen_nodes) for n in other_node.users)
166
+ and get_mutation_region_id(graph, node)
167
+ == get_mutation_region_id(graph, other_node)
168
+ ):
169
+ # move node's producers right before it
170
+ node.prepend(other_node)
171
+
172
+ seen_nodes = set()
173
+
174
+ # only reorder nodes before the first copy_ in the graph.
175
+ # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
176
+ # and this reordering doesnt work well with mutation
177
+ first_copy = next(
178
+ iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
179
+ None,
180
+ )
181
+ past_mutating_epilogue = True if first_copy is None else False
182
+
183
+ for node in reversed(graph.nodes):
184
+ seen_nodes.add(node)
185
+ if not past_mutating_epilogue:
186
+ past_mutating_epilogue = node is first_copy
187
+ continue
188
+
189
+ torch.fx.map_arg((node.args, node.kwargs), visit)
190
+
191
+
192
+ def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
193
+ """
194
+ Register an aten to inductor IR replacement pattern
195
+ """
196
+ return pattern_matcher.register_lowering_pattern(
197
+ pattern, extra_check, pass_dict=pass_patterns[pass_number]
198
+ )
199
+
200
+
201
+ ################################################################################
202
+ # Actual patterns below this point.
203
+ # Priority of patterns is:
204
+ # - later output nodes first
205
+ # - order patterns are defined in
206
+ ################################################################################
207
+
208
+
209
+ def is_valid_mm_plus_mm(match: Match):
210
+ *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape
211
+ *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape
212
+ if k1 != k2:
213
+ return False
214
+
215
+ *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape
216
+ *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape
217
+ if k3 != k4:
218
+ return False
219
+
220
+ if m1 != m2 or n1 != n2:
221
+ return False
222
+
223
+ return True
224
+
225
+
226
+ def scatter_upon_const_tensor_extra_check(m):
227
+ if not config.optimize_scatter_upon_const_tensor:
228
+ return False
229
+ full_shape = m.kwargs["shape"]
230
+ selector = m.kwargs["selector"]
231
+ dim = m.kwargs["dim"]
232
+ if dim < 0:
233
+ dim += len(full_shape)
234
+
235
+ selector_ft = selector.meta["val"]
236
+ assert selector_ft.dim() == len(full_shape)
237
+
238
+ for idx, select_sz, full_sz in zip(
239
+ itertools.count(), selector_ft.shape, full_shape
240
+ ):
241
+ if idx == dim:
242
+ continue
243
+
244
+ # TODO: the pattern can be updated to support the case that index tensor
245
+ # is shorter. But that will need a more complex condition expression
246
+ # especially for multi-dimensional tensors.
247
+ # Skip it for now.
248
+ if isinstance(full_sz, fx.Node):
249
+ full_sz = full_sz.meta["val"]
250
+ if select_sz < full_sz:
251
+ return False
252
+
253
+ # Actually we can support small size larger than 1. It would be a bit
254
+ # tedius. E.g., we load all the index values (not many) and compare
255
+ # them with the position in tensor to decide what value to return.
256
+ return selector_ft.size(dim) == 1
257
+
258
+
259
+ @register_lowering_pattern(
260
+ CallFunction(
261
+ aten.scatter.value,
262
+ CallFunction(
263
+ aten.full,
264
+ KeywordArg("shape"),
265
+ KeywordArg("background_val"),
266
+ dtype=KeywordArg("dtype"),
267
+ ),
268
+ KeywordArg("dim"),
269
+ KeywordArg("selector"),
270
+ KeywordArg("val"), # scalar value
271
+ ),
272
+ extra_check=scatter_upon_const_tensor_extra_check,
273
+ )
274
+ def scatter_upon_const_tensor(
275
+ match: Match, shape, background_val, dtype, dim, selector, val
276
+ ):
277
+ """
278
+ Match the pattern of full+scatter into a pointwise.
279
+
280
+ TODO: Right now the scatter value must be a scalar. But we could support it
281
+ when it is a tensor as well.
282
+ """
283
+ from torch._inductor import metrics
284
+
285
+ metrics.num_matches_for_scatter_upon_const_tensor += 1
286
+
287
+ selector_loader = selector.make_loader()
288
+
289
+ def inner_fn(idx):
290
+ selector_idx = list(idx)
291
+ selector_idx[dim] = 0
292
+
293
+ selector = selector_loader(selector_idx)
294
+ return ops.where(
295
+ selector == ops.index_expr(idx[dim], torch.int64),
296
+ ops.constant(val, dtype),
297
+ ops.constant(background_val, dtype),
298
+ )
299
+
300
+ return ir.Pointwise.create(
301
+ device=selector.get_device(),
302
+ dtype=dtype,
303
+ inner_fn=inner_fn,
304
+ ranges=shape,
305
+ )
306
+
307
+
308
+ @register_lowering_pattern(
309
+ CallFunction(
310
+ aten.add,
311
+ CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
312
+ CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
313
+ ),
314
+ extra_check=is_valid_mm_plus_mm,
315
+ )
316
+ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
317
+ return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
318
+
319
+
320
+ def cuda_and_enabled_mixed_mm(match):
321
+ return (
322
+ (config.use_mixed_mm or config.mixed_mm_choice != "default")
323
+ and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
324
+ and (
325
+ match.kwargs["mat2_dtype"].itemsize
326
+ > match.kwargs["mat2"].meta.get("val").dtype.itemsize
327
+ )
328
+ and has_backend_feature("cuda", BackendFeature.TRITON_TEMPLATES)
329
+ )
330
+
331
+
332
+ def cuda_and_enabled_mixed_mm_and_not_int8(match):
333
+ return (
334
+ cuda_and_enabled_mixed_mm(match)
335
+ and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
336
+ and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
337
+ != torch.int8
338
+ ) # bitshift numerics in triton and pytorch don't match for torch.int8
339
+
340
+
341
+ """
342
+ this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
343
+ (where the int4 and uint4x2 are represented with int8 and uint8 respectively)
344
+ where every other row of the int4 is packed with the row above it as:
345
+ uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4
346
+
347
+ unpack formulas:
348
+ int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
349
+ int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8
350
+
351
+ thus matching on unpack formula:
352
+ torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
353
+
354
+ note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
355
+ of the kernel matches the pytorch formula for all dtypes except torch.int8
356
+ where the bitwise numerics in triton do not match those in pytorch.
357
+ """
358
+
359
+
360
+ @register_lowering_pattern(
361
+ CallFunction(
362
+ aten.mm.default,
363
+ KeywordArg("mat1"),
364
+ CallFunction(
365
+ aten.sub.Tensor,
366
+ CallFunction(
367
+ prims.convert_element_type.default,
368
+ CallFunction(
369
+ aten.reshape.default,
370
+ CallFunction(
371
+ aten.cat.default,
372
+ ListOf(
373
+ CallFunction(
374
+ aten.bitwise_and.Scalar,
375
+ KeywordArg("mat2"),
376
+ 0xF,
377
+ ),
378
+ # CallFunction(
379
+ # aten.__rshift__.Scalar,
380
+ # KeywordArg("mat2"),
381
+ # 4,
382
+ # ),
383
+ True,
384
+ ),
385
+ 1,
386
+ ),
387
+ KeywordArg("mat2_mm_shape"),
388
+ ),
389
+ KeywordArg("mat2_dtype"),
390
+ ),
391
+ 8,
392
+ ),
393
+ ),
394
+ extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
395
+ )
396
+ def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
397
+ return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
398
+ mat1, mat2, mat2_mm_shape, mat2_dtype
399
+ )
400
+
401
+
402
+ """
403
+ torch.mm(mat1, mat2.to(mat2_dtype))
404
+ """
405
+
406
+
407
+ @register_lowering_pattern(
408
+ CallFunction(
409
+ aten.mm,
410
+ KeywordArg("mat1"),
411
+ CallFunction(
412
+ prims.convert_element_type.default,
413
+ KeywordArg("mat2"),
414
+ KeywordArg("mat2_dtype"),
415
+ ),
416
+ ),
417
+ extra_check=cuda_and_enabled_mixed_mm,
418
+ )
419
+ def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
420
+ return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
421
+
422
+
423
+ @register_graph_pattern(
424
+ CallFunction(
425
+ aten.cumsum.default,
426
+ CallFunction(
427
+ torch.ops.aten.full.default,
428
+ KeywordArg("shape"),
429
+ KeywordArg("fill_value"),
430
+ dtype=KeywordArg("dtype"),
431
+ layout=Ignored(),
432
+ device=KeywordArg("device"),
433
+ pin_memory=False,
434
+ _users=MULTIPLE,
435
+ ),
436
+ KeywordArg("dim"),
437
+ _users=MULTIPLE,
438
+ ),
439
+ pass_dict=pass_patterns[1],
440
+ )
441
+ def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
442
+ """Based on a pattern in OPTForCausalLM"""
443
+
444
+ if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
445
+ # cumsum promotes all integral types to int64
446
+ dtype = torch.int64
447
+
448
+ def repl(*shape):
449
+ dim_size = shape[dim]
450
+ idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)
451
+
452
+ inter_shape = [1] * len(shape)
453
+ inter_shape[dim] = dim_size
454
+ return (idx * fill_value).view(inter_shape).expand(shape)
455
+
456
+ # only replace the output node, not all nodes
457
+ match.nodes = [match.output_node()]
458
+ match.replace_by_example(repl, list(shape))
459
+
460
+
461
+ def shape_of_mm(a, b):
462
+ m, _ = a.get_size()
463
+ _, n = b.get_size()
464
+ return [m, n]
465
+
466
+
467
+ @register_lowering_pattern(
468
+ CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
469
+ )
470
+ def cat_mm(match, inputs, dim):
471
+ return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)
472
+
473
+
474
+ @register_lowering_pattern(
475
+ CallFunction(
476
+ aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
477
+ ),
478
+ )
479
+ def cat_addmm(match, inputs, dim):
480
+ def shape_of(bias, a, b):
481
+ m, _ = a.get_size()
482
+ _, n = b.get_size()
483
+ return [m, n]
484
+
485
+ return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
486
+
487
+
488
def cat_tuned_op(match, inputs, dim, *, op, shape_of):
    """
    Memory planning to remove cat. We can't use the stock memory
    planner since autotuning matmuls needs to know the output layout.

    Args:
        match: the pattern-matcher Match that triggered this lowering.
        inputs: list of argument tuples; each tuple is one call to ``op``.
        dim: dimension along which the original cat joined the op outputs
            (must normalize to 0 or 1 — these are 2D matmul-like outputs).
        op: keyword-only; the lowering to run per input tuple
            (e.g. ``L[aten.addmm]``).
        shape_of: keyword-only; maps one input tuple to its output shape.

    Returns:
        A TensorBox over an ``ir.ConcatKernel`` whose slices are written
        directly by each ``op`` call, so no separate cat kernel is needed.
    """
    # A single operand needs no concatenation at all.
    if len(inputs) == 1:
        return op(*inputs[0])

    # TODO(jansel): rewrite this as a bmm?
    if dim < 0:
        dim += len(shape_of(*inputs[0]))
    assert dim in (0, 1)
    notdim = 1 - dim

    new_size: Optional[Union[List[Expr], List[int]]] = None
    offsets_start = []
    offsets_end = []

    # compute output sizes
    for i in range(len(inputs)):
        shape = shape_of(*inputs[i])
        if new_size is None:
            new_size = shape
        else:
            # Every piece must agree on the non-concatenated dimension.
            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
                shape[notdim], new_size[notdim]
            )
            new_size[dim] += shape[dim]
        # Record the [start, end) slice this piece occupies along `dim`.
        offsets_start.append(new_size[dim] - shape[dim])
        offsets_end.append(new_size[dim])

    assert new_size is not None
    # The output dtype follows normal promotion across all operands.
    dtype = functools.reduce(
        torch.promote_types,
        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
    )
    device = inputs[0][0].get_device()
    kernel = ir.ConcatKernel(
        name=None,
        layout=ir.FixedLayout(device, dtype, new_size),
        inputs=[],
    )
    kernel_tensor = ir.TensorBox.create(kernel)

    for i in range(len(inputs)):
        # Each op writes straight into its slice of the concat buffer:
        # give it the slice's layout and mark it as non-owning.
        dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
        src = op(*inputs[i], layout=dst.get_layout()).data.data
        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
        src.layout = ir.NonOwningLayout(dst)
        kernel.inputs.append(src)

    kernel.name = V.graph.register_buffer(kernel)
    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
    V.graph.register_operation(kernel)
    return kernel_tensor
543
+
544
+
545
# Shared pattern node: the inner cat is referenced twice below (once as a
# direct cat input, once through a slice), hence _users=2.
_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)


@register_lowering_pattern(
    CallFunction(
        aten.cat,
        [
            _cat_1,
            CallFunction(
                aten.slice,
                _cat_1,
                1,
                0,
                KeywordArg("size"),
            ),
        ],
        1,
    )
)
def cat_slice_cat(match, cat_input, size, dim=1):
    """
    This is an example of a more complex pattern where cat_1 is used
    multiple times inside the pattern. We fold 2 calls to cat into one.

    Matches:
        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)


    Rewrite to:
        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
    """
    first, *rest = cat_input
    # Optimization is optional, because we can just not fold the cat
    # size should be within first.get_size()[dim] such that the optimization is valid.
    # For negative `end`, we currently fallback to not optimizing.
    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
        # fold 2 cats into 1 cat: slice only needs `first` since the slice
        # is provably contained in the first cat operand
        return L[aten.cat](
            [
                first,
                *rest,
                L[aten.slice](first, dim, 0, size),
            ],
            dim,
        )
    else:
        # don't expect to hit this case, just fall back: materialize the
        # inner cat and slice it, exactly as the original graph did
        tmp = L[aten.cat](cat_input, dim)
        return L[aten.cat](
            [
                tmp,
                L[aten.slice](tmp, dim, 0, size),
            ],
            dim,
        )
604
+
605
+
606
def is_valid_splitwithsizes_cat(match):
    """Check that a split_with_sizes -> getitem* -> cat chain is a pure
    passthrough: same dim on both ends, every split piece consumed, and the
    cat consumes the pieces in their original order."""
    splits = filter_nodes(match.nodes, aten.split_with_sizes)
    cats = filter_nodes(match.nodes, aten.cat)
    getitems = filter_nodes(match.nodes, operator.getitem)
    if len(splits) != 1 or len(cats) != 1:
        return False
    split_node = splits[0]
    cat_node = cats[0]
    # Split and cat must operate on the same dimension to cancel out.
    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
        return False
    indices_used = {get_arg_value(g, 1) for g in getitems}
    assert None not in indices_used
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # Every piece of the split must be fed into the cat.
    if indices_used != set(range(len(split_sizes))):
        return False
    # The cat must consume the pieces in their original order, e.g. for
    # split_with_sizes(input, [2, 2, 3], 1) the cat should be
    # cat([get_item(0), get_item(1), get_item(2)], 1).
    cat_order = [
        get_arg_value(item, 1) for item in get_arg_value(cat_node, 0)
    ]
    if cat_order != list(range(len(split_sizes))):
        return False

    return True
634
+
635
+
636
def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
    """True if two nodes carry fake-tensor metadata agreeing on size, layout,
    dtype, device, and (for strided tensors) stride."""
    val1 = node1.meta.get("val")
    val2 = node2.meta.get("val")
    if val1 is None or val2 is None:
        return False
    if not statically_known_true(sym_eq(val1.size(), val2.size())):
        return False
    if val1.layout != val2.layout:
        return False
    if val1.dtype != val2.dtype or val1.device != val2.device:
        return False
    # Non-strided layouts carry no strides, so nothing more to compare.
    return val1.layout != torch.strided or statically_known_true(
        sym_eq(val1.stride(), val2.stride())
    )
652
+
653
+
654
# Maps an op overload to (cond, nop_arg): `cond` decides whether a concrete
# call is a no-op, and `nop_arg` names which argument the call is a no-op of
# (either an int positional index or a callable over the node's args).
noop_registry: Dict[Any, Any] = {}


def register_noop_decomp(targets, nop_arg=0):
    """Decorator factory: register the decorated predicate in noop_registry
    for `targets`, pairing it with `nop_arg` (see noop_registry comment)."""

    def register_fun(cond):
        # register_decomposition is reused purely as a registry helper here;
        # the stored value is the (cond, nop_arg) pair, not a decomposition.
        register_decomposition(targets, registry=noop_registry, unsafe=True)(
            (cond, nop_arg)  # type: ignore[arg-type]
        )
        return cond

    return register_fun
665
+
666
+
667
@register_noop_decomp(aten.slice)
def slice_noop(self, dim=0, start=None, end=None, step=1):
    """A slice covering the whole dimension (start 0, end at int64 max,
    step 1) is a no-op of `self`."""
    if start is None or end is None:
        return False
    if (
        statically_known_true(sym_eq(start, 0))
        # aten canonicalizes "slice to the end" as end == 2**63 - 1 (INT64_MAX)
        and statically_known_true(end >= 2**63 - 1)
        and statically_known_true(sym_eq(step, 1))
    ):
        return True
    return False
678
+
679
+
680
@register_noop_decomp(aten.slice_scatter, 1)
def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
    """slice_scatter over the whole dimension just copies `src` (nop_arg=1),
    so the op itself is a no-op."""
    # Normalize the aten defaults: start omitted means 0, end omitted means
    # INT64_MAX ("to the end").
    lo = 0 if start is None else start
    hi = 2**63 - 1 if end is None else end
    if lo == 0 and hi >= 2**63 - 1 and step == 1:
        return True
    return False
689
+
690
+
691
@register_noop_decomp(aten.repeat)
def repeat_noop(self, repeats):
    """repeat is a no-op when every repetition count is exactly one."""
    for count in repeats:
        if count != 1:
            return False
    return True
694
+
695
+
696
@register_noop_decomp(aten.constant_pad_nd)
def constant_pad_nd(x, padding, fill_value=0):
    """Padding zero elements on every side leaves the tensor unchanged."""
    return not any(amount != 0 for amount in padding)
699
+
700
+
701
@register_noop_decomp(torch.ops.prims.convert_element_type)
def convert_element_type_noop(x, dtype: torch.dtype):
    """Casting a tensor to the dtype it already has is a no-op."""
    return dtype == x.dtype
704
+
705
+
706
@register_noop_decomp(torch.ops.prims.device_put)
def device_put_noop(x, device):
    """Moving a tensor to the device it already lives on is a no-op."""
    target = decode_device(device)
    return target == x.device
709
+
710
+
711
@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
def int_noop(x):
    """Rounding ops are identity functions on integer tensors."""
    dtype = x.dtype
    return is_integer_dtype(dtype)
714
+
715
+
716
@register_noop_decomp([aten.pow])
def pow_noop(a, b):
    """x ** 1 is the identity, but only for a literal int exponent."""
    if isinstance(b, int):
        return b == 1
    return False
719
+
720
+
721
@register_noop_decomp([aten.cat], lambda args: args[0][0])
def cat_noop(inputs, dim=0):
    """Concatenating a single tensor returns that tensor unchanged; the
    nop_arg lambda picks out that sole input."""
    return 1 == len(inputs)
724
+
725
+
726
@register_noop_decomp(aten.view)
def view_noop(arg, size):
    """view to the tensor's current shape is a no-op."""
    if arg.shape == size:
        return True
    return False
729
+
730
+
731
# Note, we also always have a check for identical metadata, which is why these
# are safe: remove_noop_ops only fires when same_meta(node, src) also holds.
@register_noop_decomp([aten.copy], nop_arg=1)
@register_noop_decomp([aten.alias, aten.clone])
def true_noop(*args, **kwargs):
    # Unconditionally a no-op; copy is a no-op of its *source* (nop_arg=1),
    # alias/clone of their input (default nop_arg=0).
    return True
737
+
738
+
739
def remove_noop_ops(graph: torch.fx.Graph):
    """
    Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.

    For every node whose target is in noop_registry, replace the node with
    the argument it is a no-op of, provided metadata matches and doing so
    does not introduce new aliasing between graph inputs and outputs.
    """
    inputs = set()
    input_storages = set()
    output_storages = set()

    # Collect the storages backing the graph's placeholders ...
    for node in graph.find_nodes(op="placeholder"):
        inputs.add(node)
        input_storages.add(get_node_storage(node))

    # ... and its outputs (the output node is always the last node).
    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output"
    outputs = output_node.args[0]
    if not isinstance(outputs, (list, tuple)):
        # nested subgraphs can have singleton outputs
        outputs = (outputs,)
    for out in outputs:
        if isinstance(out, torch.fx.Node):
            output_storages.add(get_node_storage(out))

    for node in graph.nodes:
        if node.target in noop_registry:
            cond, src_index = noop_registry[node.target]
            # src is the value this node would be a no-op of: either a
            # positional arg index or a callable over the args.
            if isinstance(src_index, int):
                src = node.args[src_index]
            else:
                src = src_index(node.args)
            if not isinstance(src, torch.fx.Node):
                continue
            # Don't introduce new aliasing between inputs and outputs.
            # See fx_passes/README.md for a discussion of why this is
            # necessary.
            node_storage = get_node_storage(node)
            src_storage = get_node_storage(src)
            node_is_view = node_storage == src_storage
            if (
                not node_is_view
                and node_storage in output_storages
                and (src_storage in input_storages or src_storage in output_storages)
            ):
                continue

            # Even if input and outputs are expected to alias,
            # don't make "node is src" True
            # NOTE(review): output_node.args is typically a 1-tuple holding
            # the output list; the membership test here relies on that shape
            # — confirm against all callers.
            if (
                node_is_view
                and node in output_node.args
                and (src in inputs or src in output_node.args)
            ):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            # Only erase when metadata matches exactly AND the registered
            # predicate confirms the call is a no-op for these arguments.
            if same_meta(node, src) and cond(*args, **kwargs):
                node.replace_all_uses_with(src)
                graph.erase_node(node)
798
+
799
+
800
def decompose_auto_functionalized(graph):
    """Decomposes auto_functionalized and triton_kernel_wrapper_functional
    nodes into clones and the underlying mutation node.

    We assume that the reinplacing pass runs before this; the reinplacing pass
    tells us (via rewriting the arguments or .meta to those nodes) which
    Tensors we should clone and which Tensors are safe to reinplace.
    """
    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense

        # Populated by the reinplacing pass: tensors still needing clones.
        only_clone_these_tensors = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_functional_dense,
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: same flattening trick as above — make_fx cannot trace kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return (triton_kernel_wrapper_functional_dense(*args, **kwargs),)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import (
            auto_functionalized_v2_dense,
        )

        only_clone_these_bases = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: same flattening trick as above — make_fx cannot trace kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    graph_pass.apply(graph)

    # Post-condition: none of the functionalized HOPs may survive the pass.
    # Each loop raises on the first offending node it finds.
    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized
    ):
        raise AssertionError("auto_functionalized was not removed")

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
    ):
        raise AssertionError("auto_functionalized_v2 was not removed")

    for node in graph.find_nodes(
        op="call_function",
        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
    ):
        raise AssertionError("triton_kernel_wrapper_functional was not removed")
893
+
894
+
895
@register_lowering_pattern(
    CallFunction(
        aten.cat,
        ListOf(
            CallFunction(
                operator.getitem,
                CallFunction(
                    aten.split_with_sizes,
                    KeywordArg("input_"),
                    Ignored(),
                    Ignored(),
                    _users=MULTIPLE,
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_splitwithsizes_cat,
)
def splitwithsizes_cat_replace(match, input_):
    """cat(split_with_sizes(x, ...)) over every piece, in order and on the
    same dim (validated by the extra_check), is just x."""
    return input_
918
+
919
+
920
def is_valid_cat_splitwithsizes(match):
    """Check that split_with_sizes(cat(xs, d), sizes, d) simply hands the
    original cat inputs back: same dim, cat used nowhere else, and each cat
    input's size along dim equals the corresponding split size."""
    cats = filter_nodes(match.nodes, aten.cat)
    splits = filter_nodes(match.nodes, aten.split_with_sizes)
    if len(splits) != 1 or len(cats) != 1:
        return False
    split_node = splits[0]
    cat_node = cats[0]

    # A cat consumed elsewhere must be materialized anyway; can't eliminate.
    if len(cat_node.users) > 1:
        return False

    # Split and cat have to act on the same dimension to cancel out.
    dim = get_arg_value(split_node, 2, "dim")
    if dim != get_arg_value(cat_node, 1, "dim"):
        return False

    cat_inputs = list(get_arg_value(cat_node, 0))
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # One split piece per cat input, or the shapes don't line up.
    if len(cat_inputs) != len(split_sizes):
        return False

    # Each cat input's extent along `dim` must match its split size exactly.
    for tensor_node, expected_size in zip(cat_inputs, split_sizes):
        if "val" not in tensor_node.meta:
            return False
        if tensor_node.meta["val"].size(dim) != expected_size:
            return False

    return True
953
+
954
+
955
@register_lowering_pattern(
    CallFunction(
        aten.split_with_sizes,
        CallFunction(
            aten.cat,
            KeywordArg("input_"),
            Ignored(),
            _users=MULTIPLE,
        ),
        Ignored(),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_cat_splitwithsizes,
)
def cat_splitwithsizes_replace(match, input_):
    """split_with_sizes(cat(xs, d), sizes, d) with matching sizes (validated
    by the extra_check) just returns the original cat inputs xs."""
    return input_
972
+
973
+
974
def view_to_reshape(gm):
    """
    Replace view ops in the GraphModule to reshape ops.
    """
    view_op = torch.ops.aten.view.default
    reshape_op = torch.ops.aten.reshape.default
    # Retarget in place; args/kwargs are identical for view and reshape.
    for node in gm.graph.find_nodes(op="call_function", target=view_op):
        node.target = reshape_op
982
+
983
+
984
def should_prefer_unfused_addmm(match):
    """On CUDA, prefer mm + separate bias-add when every consumer of the
    result is pointwise, so the add can fuse into those consumers."""
    bias = match.kwargs["inp"]
    if not bias.meta["val"].is_cuda:
        return False

    result = match.output_node()
    for consumer in result.users:
        if not is_pointwise_use(consumer):
            return False
    return True
991
+
992
+
993
@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    """Split addmm into mm + add so the bias-add can fuse with the pointwise
    consumers (gated by should_prefer_unfused_addmm)."""

    def repl(inp, x1, x2):
        # Equivalent computation with the add kept as a separate node.
        return x1 @ x2 + inp

    match.replace_by_example(repl, [inp, mat1, mat2])
1003
+
1004
+
1005
def is_valid_addmm_fusion(match):
    """add(mm(a, b), bias) may fold into addmm only when bias is a tensor
    node broadcastable to the mm output, and unfused addmm isn't preferred."""
    mat1, mat2 = match.args
    bias = match.kwargs["inp"]

    # A scalar bias (plain number) can't go through addmm's tensor path.
    if not isinstance(bias, torch.fx.Node):
        return False  # Input is a number
    if not isinstance(bias.meta["val"], torch.Tensor):
        return False  # Input is a number

    bias_shape = bias.meta["val"].shape
    mm_shape = (mat1.meta["val"].shape[0], mat2.meta["val"].shape[1])
    if not is_expandable_to(bias_shape, mm_shape):
        return False  # Shape mismatch

    return not should_prefer_unfused_addmm(match)
1021
+
1022
+
1023
# Both orderings of add(mm(...), inp) / add(inp, mm(...)) fold into addmm.
@register_graph_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, Arg(), Arg()),
        KeywordArg("inp"),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
@register_graph_pattern(
    CallFunction(
        aten.add,
        KeywordArg("inp"),
        CallFunction(aten.mm, Arg(), Arg()),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
def addmm(match, mat1, mat2, *, inp):
    """Fuse mm + bias-add into a single addmm (gated by
    is_valid_addmm_fusion)."""

    def repl(inp, mat1, mat2):
        return aten.addmm(inp, mat1, mat2)

    match.replace_by_example(repl, [inp, mat1, mat2])
1046
+
1047
+
1048
def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
    """Gate for the fused int-mm-mul lowering: the config flag must be on and
    the multiplier operand must be a 2D CUDA tensor."""
    multiplier = match.args[2].meta.get("val")
    is_2d = len(getattr(multiplier, "shape", [])) == 2
    on_cuda = getattr(multiplier, "is_cuda", False)
    return config.force_fuse_int_mm_with_mul and is_2d and on_cuda
1054
+
1055
+
1056
# Two variants: _int_mm * mul optionally followed by a dtype conversion
# (convert_element_type supplies out_dtype; the bare mul form leaves it None).
@register_lowering_pattern(
    CallFunction(
        prims.convert_element_type.default,
        CallFunction(
            aten.mul,
            CallFunction(
                aten._int_mm,
                Arg(),
                Arg(),
            ),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
@register_lowering_pattern(
    CallFunction(
        aten.mul,
        CallFunction(
            aten._int_mm,
            Arg(),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
    # Lower the whole chain to the autotuned fused int-mm-mul kernel.
    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
1086
+
1087
+
1088
def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
    """Return True if `node` is an index_put(_) whose indices include a
    bool/byte mask tensor and whose value is a CPU scalar tensor.

    In that situation inductor should *not* move the CPU scalar value to the
    GPU (see the inline comments for why).
    """
    from torch.fx.operator_schemas import normalize_function

    if node.target not in [
        torch.ops.aten.index_put.default,
        torch.ops.aten.index_put_.default,
    ]:
        return False
    # Inductor falls back to aten.index_put_.
    # index_put_ will call nonzero() and perform a H2D sync if
    # any of its indices are bool/byte tensors.
    # However, it will short-circuit this H2D sync and run masked_fill_
    # if the value we are putting is a cpu scalar.
    # Therefore, when inductor sees an index_put_ with byte tensor indices,
    # it should *not* convert the cpu scalar value into a gpu tensor.
    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
    indices = args_[1]
    # BUGFIX: "byte" masks are torch.uint8, not torch.int8 — advanced
    # indexing only treats bool/uint8 index tensors as masks, so the
    # original torch.int8 check never matched an actual byte mask.
    any_byte_bool_indices = any(
        i is not None and i.meta["val"].dtype in (torch.bool, torch.uint8)
        for i in indices
    )

    val = args_[2].meta["val"]
    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
    # If both these conditions hold, then converting the val
    # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_
    return any_byte_bool_indices and val_is_cpu_scalar
1115
+
1116
+
1117
class ConstructorMoverPass:
    def __init__(self, target: str, allow_outputs: bool = False) -> None:
        """
        Move constructors from cpu to the target_device.

        Sweeps through the module, looking for constructor nodes that can be moved
        to the target_device.

        A constructor node can be moved to the target_device iff all of its users
        can also be moved (tested by cannot_be_moved). Otherwise, all dependent
        constructor nodes won't be moved.

        - target: target device type
        - allow_outputs: allow outputs to be moved
        """

        self.target = target
        self.allow_outputs = allow_outputs

        assert isinstance(target, str), (
            "target should be a string representing the device type. "
            f"Got: {type(target).__name__}"
        )

    def allow_cpu_device(self, node: fx.Node) -> bool:
        """
        Returns whether a node that returns a tensor on the target device may have
        cpu tensors as input.
        """
        # These ops natively accept mixed cpu/device operands.
        return node.target in (
            torch.ops.aten.index.Tensor,
            torch.ops.aten.index_put.default,
            torch.ops.aten.index_put_.default,
            torch.ops.aten.copy.default,
            torch.ops.aten.copy_.default,
            torch.ops.aten.slice_scatter.default,
        )

    def cannot_be_moved(self, node: fx.Node) -> bool:
        """
        Returns whether a node can be moved to the target device.

        If this function returns False, it means that this node and all of its users
        won't be moved into the target device.
        """
        if node.target == "output":
            return not self.allow_outputs

        # Only core prims/aten overloads are known-safe to retarget.
        if not (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target.namespace in ("prims", "aten")
        ):
            return True
        # Moving the value here would force a host-to-device sync; keep it.
        if is_index_put_and_requires_h2d_sync_for_gpu_value(node):
            return True

        return False

    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
        """
        Get the device of a node.
        """
        ten = node.meta.get("val")
        return None if not isinstance(ten, torch.Tensor) else ten.device

    def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
        """
        Get the number of cpu inputs to a node
        """
        cpu_indeg: Dict[fx.Node, int] = Counter()

        for node in graph.nodes:
            cpu_count = 0

            def add_cpu_inp(node):
                nonlocal cpu_count
                device = self.get_node_device(node)
                # bool adds as 0/1; counts every cpu-resident tensor input
                cpu_count += device is not None and device.type == "cpu"

            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))

            if cpu_count:
                cpu_indeg[node] = cpu_count

        return cpu_indeg

    def __call__(self, graph: fx.Graph) -> None:
        target_devices = set()
        constructors = []

        # Gather (a) all concrete target devices seen in the graph, and
        # (b) every cpu tensor-constructor candidate.
        for node in graph.nodes:
            device = self.get_node_device(node)
            if device and device.type == self.target:
                target_devices.add(device)

            if not (
                isinstance(node.target, torch._ops.OpOverload)
                and node.target.namespace in ("prims", "aten")
            ):
                continue

            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
                continue

            # Only constructors explicitly pinned to cpu via kwargs qualify.
            if not node.kwargs.get("device") == torch.device("cpu"):
                continue

            constructors.append(node)

        # not handling multiple target devices initially
        if not constructors or len(target_devices) != 1:
            return

        movable_constructors = self.find_movable_constructors(graph, constructors)

        # Retarget each movable constructor by rewriting its device kwarg.
        for node in movable_constructors:
            kwargs = node.kwargs.copy()
            kwargs["device"] = next(iter(target_devices))
            node.kwargs = kwargs

    def find_movable_constructors(
        self, graph: fx.Graph, constructors: List[fx.Node]
    ) -> Set[fx.Node]:
        """
        Starting from the cpu constructors, iterate through the graph and test that all of their
        downstream uses can safely be moved to cpu.
        """
        cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)

        # which constructors cannot be moved to gpu
        cannot_move_to_gpu: Set[fx.Node] = set()

        # For any node in the graph, which constructors does it have a dependency on
        constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)

        # if a cpu node has a dependency on two different cpu constructors,
        # then if either constructor cannot be moved to gpu, the other cannot as well.
        # In this case any node with a dependency on one will have a dependency on the other
        equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
            c: {c} for c in constructors
        }

        def make_dependencies_equivalent(
            set1: Set[fx.Node], set2: Set[fx.Node]
        ) -> Set[fx.Node]:
            # could use union find but not worth complexity here
            set1.update(set2)
            for obj in set1:
                equal_constructor_sets[obj] = set1
            return set1

        queue: List[fx.Node] = list(constructors)

        for c in queue:
            constructor_dependencies[c].add(c)

        while queue:
            node = queue.pop()
            dependencies = constructor_dependencies[node]

            for user in node.users:
                if self.cannot_be_moved(user):
                    cannot_move_to_gpu.update(dependencies)
                    break

                # this node was used on a op which takes in multiple devices and output a gpu
                # tensor. we can convert its cpu input to gpu without making further changes
                node_device = self.get_node_device(user)
                if (
                    self.allow_cpu_device(user)
                    and node_device
                    and node_device.type == self.target
                ):
                    del cpu_indeg[user]
                else:
                    # otherwise, we should continue look at its downstream uses
                    cpu_indeg[user] -= 1
                    if cpu_indeg[user] == 0:
                        del cpu_indeg[user]
                        queue.append(user)

                unioned_set = make_dependencies_equivalent(
                    dependencies, constructor_dependencies[user]
                )
                constructor_dependencies[user] = unioned_set

        # Any node that still has cpu inputs pins all constructors it
        # depends on to cpu.
        for node in cpu_indeg:
            if constructor_dependencies[node]:
                cannot_move_to_gpu.update(constructor_dependencies[node])

        # Close over the equivalence sets: pinning one constructor pins
        # everything in its set.
        all_cannot_move_to_gpu = cannot_move_to_gpu.copy()
        for constructor in cannot_move_to_gpu:
            all_cannot_move_to_gpu.update(equal_constructor_sets[constructor])

        return set(constructors) - all_cannot_move_to_gpu
1312
+
1313
+
1314
def move_constructors_to_gpu(graph: fx.Graph) -> None:
    """Move cpu-constructed intermediary tensors onto the gpu when it is
    provably safe to do so."""
    mover = ConstructorMoverPass(get_gpu_type())
    mover(graph)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import copy
3
+ import itertools
4
+ import logging
5
+ from typing import Dict, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
10
+ from torch._utils_internal import upload_graph
11
+ from torch.fx.experimental.optimization import (
12
+ matches_module_pattern,
13
+ replace_node_module,
14
+ )
15
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
16
+ from torch.fx.passes.shape_prop import ShapeProp
17
+ from torch.nn import functional as F
18
+ from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
19
+
20
+ from .. import config
21
+ from ..fx_utils import matches_module_function_pattern
22
+ from ..pattern_matcher import (
23
+ init_once_fakemode,
24
+ PatternMatcherPass,
25
+ stable_topological_sort,
26
+ )
27
+ from ..utils import is_cpu_device, pass_execution_and_save
28
+ from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
29
+ from .misc_patterns import numpy_compat_normalization
30
+ from .split_cat import PRE_GRAD_PATTERNS
31
+
32
+
33
log = logging.getLogger(__name__)

# Torch-IR pattern passes; patterns are registered elsewhere (presumably by
# the modules imported in lazy_init — confirm against those modules).
efficient_conv_bn_eval_pass = PatternMatcherPass(
    pass_name="efficient_conv_bn_eval_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass()
merge_splits_pass_aten = PatternMatcherPass()
split_cat_pass_aten = PatternMatcherPass()
unbind_stack_pass_aten = PatternMatcherPass()
merge_getitem_cat_pass_aten = PatternMatcherPass()
merge_stack_tahn_unbind_pass_aten = PatternMatcherPass()
mutate_cat_pass_aten = PatternMatcherPass()
remove_split_with_size_one_pass_aten = PatternMatcherPass()
58
+
59
+
60
def save_inductor_dict(pass_to_compare=None):
    """Snapshot the inductor counters for the given pass names (defaulting to
    every configured pre/post-grad fusion pass), missing entries as 0."""
    if not pass_to_compare:
        pass_to_compare = [
            *config.pre_grad_fusion_options.keys(),
            *config.post_grad_fusion_options.keys(),
        ]
    snapshot = dict(counters["inductor"])
    return {name: snapshot.get(name, 0) for name in pass_to_compare}
66
+
67
+
68
def is_same_dict(inductor_dict, optimus_dict):
    """Return True iff every counter named in optimus_dict has the same value
    in inductor_dict (absent counters count as 0)."""
    current = dict(inductor_dict)
    return all(
        current.get(name, 0) == expected for name, expected in optimus_dict.items()
    )
73
+
74
+
75
def normalize_node_kwargs_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
77
+
78
+
79
def fuse_parallel_linear_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
81
+
82
+
83
def remove_split_ops(graph, shape_prop):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
85
+
86
+
87
def fuse_chunk_reshape_unsqueeze_concat_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
89
+
90
+
91
def fuse_chunk_reshape_concat_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
93
+
94
+
95
def remove_noop_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
97
+
98
+
99
def stack_to_unsqueeze_pass(graph):
    """Placeholder pass: intentionally does nothing and returns None."""
    return None
101
+
102
+
103
@init_once_fakemode
def lazy_init():
    """Import the modules whose import side effects register pre-grad
    patterns; decorated with init_once_fakemode so it runs only once
    (presumably under a fake-tensor mode — confirm in pattern_matcher)."""
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        # fbcode-only pattern modules.
        from . import fb  # type: ignore[attr-defined]  # noqa: F401
109
+
110
+
111
def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is not functional or normalized, so it is harder
    to write passes on this IR. Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py which
    are after functionalization and normalization.
    """
    if config.pattern_matcher:
        lazy_init()
        if hasattr(
            config, "fx_passes_numeric_check"
        ) and config.fx_passes_numeric_check.get("pre_grad", False):
            # Keep a pristine copy so the numeric checker below can compare
            # before/after outputs.
            gm_before_fx_passes = gm.__copy__()
        # explicitly run with predispatch atenIR based passes
        if config.is_predispatch:

            def shape_prop(mod) -> None:
                # Re-propagate fake-tensor shapes after graph mutations.
                ShapeProp(
                    gm=mod,
                    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
                    fake_mode=detect_fake_mode(example_inputs),
                ).propagate(*example_inputs)

            # normalization pass
            pass_execution_and_save(
                normalization_pass_aten.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalization pass",
            )
            # normalize kwargs, must be called as the first pass
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply remove_noop pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_concat_pass",
            )
            pass_execution_and_save(
                group_batch_fusion_passes,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply group_batch_fusion",
            )
            # Re-normalize: group-batch fusion may have introduced new nodes.
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                fuse_chunk_squeeze_cat_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
            )
            pass_execution_and_save(
                fuse_split_linear_add_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
            )
            pass_execution_and_save(
                remove_reshape_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
            )
            pass_execution_and_save(
                fuse_parallel_linear_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
            )
            # remove_split_ops needs the owning module plus a shape-prop hook.
            pass_execution_and_save(
                lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_split_ops",
            )
            # run before fuse_chunk_reshape_unsqueeze_concat_pass
            pass_execution_and_save(
                stack_to_unsqueeze_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply stack_to_unsqueeze_pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_unsqueeze_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_unsqueeze_concat_pass",
            )
            # Remove noops at the end, which may be generated other passes.
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply remove_noop pass",
            )
            # Final shape propagation so downstream consumers see fresh metadata.
            shape_prop(gm)

        else:
            # We only log the graph with changes to avoid the excessive compilation time
            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
            if example_inputs is not None:
                gm = fuse_fx(gm, example_inputs)
            numpy_compat_normalization(gm.graph)
            optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
            group_batch_fusion_passes(gm.graph, pre_grad=True)
            for pass_name in config.pre_grad_fusion_options:
                # skip all patterns for group batch fusions
                if pass_name in PRE_GRAD_FUSIONS:
                    continue
                pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
                inductor_before_change = save_inductor_dict(
                    [pattern_matcher_pass.pass_name]
                )
                # we support run same pattern multiple times, the default is to run only once
                counter = config.pre_grad_fusion_options[pass_name].get("counter", 1)
                for _ in range(counter):
                    pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
                # Only upload/log when the pass actually changed counters.
                if not is_same_dict(counters["inductor"], inductor_before_change):
                    optimus_scuba_log[
                        f"{pattern_matcher_pass.pass_name}_pre_grad"
                    ] = upload_graph(gm.graph)
            # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
            efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]

    if config.pre_grad_custom_pass is not None:
        with GraphTransformObserver(
            gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
        ):
            config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)

    from .quantization import quant_lift_up

    quant_lift_up(gm)

    gm.graph.lint()
    gm.recompile()
    optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph)

    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm
def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """Run the Torch-IR fusion passes (cat-sinking, permute/linear/matmul
    fusion, and — for CPU inference with freezing — identity removal and
    conv+BN folding) and return the transformed module."""
    is_cpu = is_cpu_device(example_inputs)
    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        with GraphTransformObserver(
            gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = linear_permute_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_linear_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_matmul_fusion(gm)

    # make sure the autograd is disabled.
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        with GraphTransformObserver(
            gm, "remove_identity", config.trace.log_url_for_graph_xform
        ):
            gm = remove_identity(gm)
        with GraphTransformObserver(
            gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
        ):
            gm = fuse_conv_bn(gm)
    return gm
def fetch_attr(target: str, mod):
    """Resolve a dotted attribute path on ``mod``.

    E.g. ``fetch_attr("sub.weight", mod)`` returns ``mod.sub.weight``.

    Args:
        target: Dot-separated attribute path.
        mod: Object (typically a module) to start resolution from.

    Returns:
        The attribute at the end of the path.

    Raises:
        RuntimeError: If a component of the path does not exist.  The
            message reports the longest prefix that did resolve (the
            failing atom itself is not included, matching the original
            behavior).
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Fix: "nonexistant" -> "nonexistent" in the error message.
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Strip every ``nn.Identity`` call from the module's graph.

    Each ``call_module`` node whose target is an ``nn.Identity`` is
    replaced by its single input; all other nodes are copied unchanged.
    """

    class _DropIdentity(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            mod = self.submodules[target]
            if not isinstance(mod, nn.Identity):
                return super().call_module(target, args, kwargs)
            # Identity is unary; forward its sole input.
            assert len(args) == 1
            return args[0]

    return _DropIdentity(gm).transform()
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
    """
    Fuses Convolution/BN layers for inference purposes.

    Handles two shapes of graph:
      1. conv module -> BatchNorm module (``modules_patterns``)
      2. conv module -> functional ``F.batch_norm`` (``module_function_patterns``)
    A conv shared by several BN consumers is only fused when every consumer
    is the *same* BN (module identity, or identical functional BN stats);
    otherwise fusion for that conv is disabled.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())

    class ConvBNFusion:
        # Bookkeeping for one conv and all the BN nodes that consume it.
        def __init__(
            self,
            bn_node,
            conv_module,
            bn_module=None,  # For BN Module
            bn_running_mean=None,  # For Functional BN
            bn_running_var=None,
            bn_eps=None,
            bn_weight=None,
            bn_bias=None,
        ) -> None:
            self.bn_nodes = [
                bn_node,
            ]
            self.conv_module = conv_module
            self.bn_module = bn_module
            self.bn_running_mean = bn_running_mean
            self.bn_running_var = bn_running_var
            self.bn_eps = bn_eps
            self.bn_weight = bn_weight
            self.bn_bias = bn_bias
            self.fusion_enabled = True

        def add_bn_node(self, bn_node):
            self.bn_nodes.append(bn_node)

        def disable_fusion(self):
            self.fusion_enabled = False

        def is_fusion_enabled(self):
            return self.fusion_enabled

    conv_bn_to_fuse: Dict[int, ConvBNFusion] = {}
    # --- Phase 1: conv module followed by a BatchNorm module ---
    for pattern in modules_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue

                # Do hash based on the module name of conv
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn)
                else:
                    if bn == conv_bn_to_fuse[hash_id].bn_module:
                        # Do fusion if same bn module
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable the conv bn folding if conv shared by different bn
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn = conv_bn_fusion.bn_module

                fused_conv = fuse_conv_bn_eval(conv, bn)
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)

    gm.graph.lint()
    # --- Phase 2: conv module followed by functional F.batch_norm ---
    for pattern in module_function_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue

                def _used_by_same_conv_module(users):
                    conv_module_name = users[0].args[0].target
                    return all(
                        conv_module_name == user.args[0].target for user in users
                    )

                # BN stats must be constants (get_attr) that feed only this
                # conv's BN consumers, so folding them is safe.
                bn_args_is_constant = all(
                    n.op == "get_attr"
                    and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users)))
                    for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue

                # Do hash based on the module name of conv
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(
                        node,
                        conv,
                        bn_running_mean=bn_running_mean,
                        bn_running_var=bn_running_var,
                        bn_eps=bn_eps,
                        bn_weight=bn_weight,
                        bn_bias=bn_bias,
                    )
                else:
                    if (
                        hash(bn_running_mean)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_mean)
                        and hash(bn_running_var)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_var)
                        and torch.allclose(
                            torch.tensor(bn_eps),
                            torch.tensor(conv_bn_to_fuse[hash_id].bn_eps),
                        )
                        and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight)
                        and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias)
                    ):
                        # Do fusion if same functional bn
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable the conv bn folding if conv shared by different bn
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn_running_mean = conv_bn_fusion.bn_running_mean
                bn_running_var = conv_bn_fusion.bn_running_var
                bn_eps = conv_bn_fusion.bn_eps
                bn_weight = conv_bn_fusion.bn_weight
                bn_bias = conv_bn_fusion.bn_bias

                # Fold the BN statistics into a copy of the conv's weights.
                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)
    gm.graph.lint()
    gm.recompile()

    return gm
class NormalizedLinearNode:
    """Uniform accessor for an ``F.linear`` ``call_function`` node.

    ``F.linear`` may be called with positional or keyword arguments; this
    wrapper hides that difference.  Positional args win when present.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        args = self.node.args
        return args[0] if len(args) > 0 else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        args = self.node.args
        if len(args) > 2:
            return args[2]  # type: ignore[return-value]
        # bias is optional: absent keyword means no bias.
        return self.node.kwargs.get("bias")  # type: ignore[return-value]
class NormalizedMatmulNode:
    """Uniform accessor for a ``torch.bmm`` / ``torch.matmul``
    ``call_function`` node, hiding positional-vs-keyword argument style.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        args = self.node.args
        return args[0] if len(args) > 0 else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["other"]  # type: ignore[return-value]
def check_permute(node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is a permute that swaps exactly the last
    two dimensions (i.e. acts like ``transpose(-1, -2)``).

    Requires ``node.meta["tensor_meta"]`` to carry the input's shape.
    Rank-2 permutes passed positionally are rejected (the positional path
    needs more than 3 args, i.e. rank >= 3).
    """
    ranks = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        # Positional form: permute(x, d0, d1, ...).
        dims = [node.args[i] % ranks for i in range(1, ranks + 1)]  # type: ignore[operator]
    else:
        perm = node.kwargs.get("permutation")
        if perm is None or len(perm) <= 2:  # type: ignore[arg-type]
            return False
        dims = [d % ranks for d in perm]  # type: ignore[union-attr]
    expected = list(range(ranks))
    expected[-2], expected[-1] = expected[-1], expected[-2]
    return dims == expected
def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Sink a unary pointwise op that follows ``torch.cat`` (possibly through
    a chain of single-user ``view`` calls) to before the cat:

        cat -> [view...] -> relu/tanh   ==>   relu/tanh per input -> cat -> [view...]
    """

    def one_user(node):
        # The sole user of `node`, or None if it has zero or several users.
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        # Only these unary pointwise ops are currently sunk.
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        # Follow the single-user view chain after the cat; `cat_or_view`
        # ends on the last view (or the cat itself), `user` on its consumer.
        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    # Normalize cat's positional/keyword calling convention.
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                new_kwargs = {
                    name: val for name, val in user.kwargs.items() if name != "input"
                }
                # Apply the pointwise op to each cat input individually.
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                # Rewire: pointwise result -> tail of the view chain,
                # old cat -> new cat over the pointwise-applied inputs.
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module
def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite ``F.linear(x, w, b).permute(..., -1, -2)`` into a single
    ``linear_transpose`` call (see that helper for the algebra)."""
    for node in module.graph.find_nodes(op="call_method", target="permute"):
        if check_permute(node):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # The linear node may now be dead; remove it if so.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fused linear + transpose of the last two dims.

    Equivalent to ``F.linear(input, weight, bias).transpose(-1, -2)``:

        Y1 = X @ W^T + bias
        Y2 = Y1.permute(0, 2, 1)
        ==>  Y2 = W @ X^T + bias.unsqueeze(-1)
    """
    out = torch.matmul(weight, input.transpose(-1, -2))
    if bias is not None:
        out = out + bias.unsqueeze(-1)
    return out
def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite ``F.linear(x.permute(..., -1, -2), w, b)`` into a single
    ``transpose_linear`` call (see that helper for the algebra)."""
    for node in module.graph.find_nodes(
        op="call_function", target=torch.nn.functional.linear
    ):
        if len(node.args) > 0:
            input_node = node.args[0]
        else:
            input_node = node.kwargs["input"]
        if (
            input_node.op == "call_method"
            and input_node.target == "permute"
            and check_permute(input_node)
        ):
            normalized = NormalizedLinearNode(node)
            if len(input_node.args) > 0:
                input = input_node.args[0]
            else:
                input = input_node.kwargs["input"]
            weight = normalized.get_weight()
            bias = normalized.get_bias()
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_linear, args=(input, weight, bias)
                )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                # The permute node may now be dead; remove it if so.
                if len(input_node.users) == 0:
                    module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fold last-two-dims permutes feeding ``torch.bmm`` / ``torch.matmul``
    into a single ``transpose_matmul`` call, per operand."""
    for node in itertools.chain(
        module.graph.find_nodes(op="call_function", target=torch.bmm),
        module.graph.find_nodes(op="call_function", target=torch.matmul),
    ):
        normalized = NormalizedMatmulNode(node)
        input_A_node = normalized.get_input()
        input_B_node = normalized.get_other()
        input_A = input_A_node
        input_B = input_B_node
        Atrans = Btrans = False
        # Peel a trailing-dims permute off operand A, if present.
        if (
            input_A_node.op == "call_method"
            and input_A_node.target == "permute"
            and check_permute(input_A_node)
        ):
            Atrans = True
            if len(input_A_node.args) > 0:
                input_A = input_A_node.args[0]  # type: ignore[assignment]
            else:
                input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

        # Peel a trailing-dims permute off operand B, if present.
        if (
            input_B_node.op == "call_method"
            and input_B_node.target == "permute"
            and check_permute(input_B_node)
        ):
            Btrans = True
            if len(input_B_node.args) > 0:
                input_B = input_B_node.args[0]  # type: ignore[assignment]
            else:
                input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

        if Atrans or Btrans:
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_matmul,
                    args=(input_A, input_B, Atrans, Btrans),
                )
            node.replace_all_uses_with(fused_node)
            module.graph.erase_node(node)
            # Erase permute nodes that became dead after the rewrite.
            if Atrans and len(input_A_node.users) == 0:
                module.graph.erase_node(input_A_node)
            if Btrans and len(input_B_node.users) == 0:
                module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Linear applied to the last-two-dims transpose of ``input``.

    Equivalent to ``F.linear(input.transpose(-1, -2), weight, bias)``:

        X1 = X.permute(0, 2, 1)
        Y1 = X1 @ W^T + bias
        ==>  Y1 = X.transpose(-1, -2) @ W^T + bias
    """
    out = torch.matmul(input.transpose(-1, -2), weight.t())
    if bias is not None:
        out = out + bias
    return out
def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    """``matmul`` with optional last-two-dims transposes folded in.

    Computes ``matmul(A' , B')`` where each operand is transposed on its
    last two dims iff the corresponding flag is True.
    """
    lhs = A.transpose(-1, -2) if Atrans else A
    rhs = B.transpose(-1, -2) if Btrans else B
    return torch.matmul(lhs, rhs)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py ADDED
@@ -0,0 +1,2589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import copy
4
+ import functools
5
+ import itertools
6
+ import math
7
+ import operator
8
+ from typing import Any, Tuple
9
+
10
+ import torch
11
+ from torch._dynamo.utils import counters
12
+ from torch.fx.experimental.symbolic_shapes import has_free_symbols
13
+ from torch.fx.node import map_arg
14
+
15
+ from ..lowering import lowerings as L, require_channels_last
16
+ from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
17
+ from ..utils import pad_listlike
18
+ from .freezing_patterns import register_freezing_graph_pattern
19
+ from .post_grad import register_lowering_pattern
20
+
21
+
22
+ aten = torch.ops.aten
23
+ prims = torch.ops.prims
24
+ quantized_decomposed = torch.ops.quantized_decomposed
25
+ quantized = torch.ops.quantized
26
+
27
+ # Only for per tensor quant since permute may changes the channel idx
28
+ _PER_TENSOR_QUANTIZE_OPS = [
29
+ quantized_decomposed.quantize_per_tensor.default,
30
+ quantized_decomposed.quantize_per_tensor.tensor,
31
+ ]
32
+
33
+ _VIEW_OPS = [
34
+ aten.transpose.int,
35
+ aten.permute.default,
36
+ aten.view.default,
37
+ ]
38
+
39
+ """
40
+ The quantization.py file primarily incorporates passes related to quantization fusion
41
+ in inductor, includes:
42
+ 1. Dequant Promotion;
43
+ 2. Conv/GEMM weight prepack with oneDNN Library;
44
+ 3. Conv/GEMM quantization fusion with output quant node (if have);
45
+ 4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
46
+
47
+ It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
48
+ of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
49
+ 1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
50
+ 2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
51
+ Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
52
+ quantization.
53
+ """
54
+
55
+
56
+ def _get_pattern_output_dtype(match: Match):
57
+ """
58
+ Get the pattern's output dtype from node's meta
59
+ Assume only 1 output node in this matched pattern.
60
+ """
61
+ pattern_output_nodes = match.output_nodes()
62
+ assert len(pattern_output_nodes) == 1
63
+ output_node = pattern_output_nodes[0]
64
+ assert isinstance(output_node, torch.fx.Node)
65
+ output_dtype = output_node.meta["val"].dtype
66
+ assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
67
+ return output_dtype
68
+
69
+
70
+ def _may_generate_pattern_with_dtype_convert(
71
+ pattern, dtype=Arg(), with_dtype_convert=True, users=1
72
+ ):
73
+ if with_dtype_convert:
74
+ return CallFunction(
75
+ prims.convert_element_type.default,
76
+ pattern,
77
+ dtype,
78
+ _users=users,
79
+ )
80
+ else:
81
+ return pattern
82
+
83
+
84
+ def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
85
+ if with_reshape:
86
+ return CallFunction(
87
+ torch.ops.aten.reshape.default,
88
+ pattern,
89
+ reshape_size,
90
+ )
91
+ else:
92
+ return pattern
93
+
94
+
95
def _generate_linear_t_pattern(
    _dequant_per_channel_pattern,
    dtype,
):
    """Build the weight permute (transpose) pattern for quantized linear.

    For bf16 an extra dtype-convert node is matched between the per-channel
    dequant and the permute.
    """
    assert dtype in [torch.float32, torch.bfloat16]
    weight_pattern = _may_generate_pattern_with_dtype_convert(
        _dequant_per_channel_pattern,
        KeywordArg("autocast_wgt_dtype"),
        dtype == torch.bfloat16,
    )
    return CallFunction(
        aten.permute.default,
        weight_pattern,
        KeywordArg("permute_axes"),
    )
def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
    """Wrap ``call_fn`` with ``unary_fusion``, matching an intermediate
    bf16->float convert node only when ``is_bf16`` is True."""
    inner = _may_generate_pattern_with_dtype_convert(
        call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
    )
    return unary_fusion(inner)
def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
    """Pattern matching a per-tensor activation dequantize.

    ``is_tensor_overload`` selects the ``.tensor`` overload (scale/zp as
    tensors) over the ``.default`` overload (python scalars).
    """
    if is_tensor_overload:
        dequant_op = quantized_decomposed.dequantize_per_tensor.tensor
    else:
        dequant_op = quantized_decomposed.dequantize_per_tensor.default
    return CallFunction(
        dequant_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("x_quant_min"),
        KeywordArg("x_quant_max"),
        KeywordArg("x_dq_dtype"),
    )
# Pattern matching a per-channel weight dequantize (fp32 weights).
dequantize_per_channel_weight_pattern = CallFunction(
    quantized_decomposed.dequantize_per_channel.default,
    KeywordArg("q_weight"),
    KeywordArg("w_scale"),
    KeywordArg("w_zp"),
    KeywordArg("w_axis"),
    KeywordArg("w_quant_min"),
    KeywordArg("w_quant_max"),
    KeywordArg("w_dtype"),
)

# Same, followed by a convert to bf16 (int8-mixed-bf16 flow).
dequantize_per_channel_to_bf16_weight_pattern = (
    _may_generate_pattern_with_dtype_convert(
        dequantize_per_channel_weight_pattern,
        KeywordArg("autocast_wgt_dtype"),
    )
)

# Dequantized weight followed by a memory-format clone.
dequantize_per_channel_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)

# bf16-converted dequantized weight followed by a memory-format clone.
dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_to_bf16_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)
def get_dequantize_qconv_pt2e_pattern(users=1):
    """Pattern matching ``onednn.qconv2d_pointwise`` as produced by weight
    prepacking (post op still "none").

    ``users`` constrains how many consumers the matched node may have.
    """
    return CallFunction(
        torch.ops.onednn.qconv2d_pointwise.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),  # x_scale
        KeywordArg("x_zp"),  # x_zp
        KeywordArg("packed_weight"),  # packed_weight
        KeywordArg("w_scale"),  # w_scale
        KeywordArg("w_zp"),  # w_zp
        KeywordArg("b"),  # bias
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("groups"),
        KeywordArg("output_scale"),  # output_scale = 1.0
        KeywordArg("output_zero_point"),  # output_zero_point = 0
        KeywordArg("output_dtype"),  # output_dtype = None
        KeywordArg("attr"),  # attr = "none"
        Arg(),  # scalars
        Arg(),  # algorithm
        _users=users,
    )
188
+
189
+
190
def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
    """Pattern matching ``onednn.qlinear_pointwise``.

    ``x_scale_zp_are_tensors`` selects the ``.tensor`` overload (activation
    scale/zero-point passed as tensors) instead of ``.default``.
    """
    if x_scale_zp_are_tensors:
        qlinear_op = torch.ops.onednn.qlinear_pointwise.tensor
    else:
        qlinear_op = torch.ops.onednn.qlinear_pointwise.default
    return CallFunction(
        qlinear_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("packed_weight"),
        KeywordArg("w_scale"),
        KeywordArg("w_zp"),
        KeywordArg("b"),
        KeywordArg("output_scale"),
        KeywordArg("output_zero_point"),
        KeywordArg("output_dtype"),
        KeywordArg("postop_name"),
        KeywordArg("postop_args"),
        KeywordArg("postop_algorithm"),
        _users=users,
    )
213
+
214
+
215
# Dequantization of the accumulator (extra input) used by binary fusions.
dequantize_accum_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    KeywordArg("accum"),
    KeywordArg("accum_scale"),
    KeywordArg("accum_zp"),
    Arg(),  # quant_min (not needed by the lowering)
    Arg(),  # quant_max (not needed by the lowering)
    KeywordArg("accum_dq_dtype"),
)
224
+
225
+
226
def generate_pattern_with_binary(
    binary_post_op,
    computation_call,
    extra_input_pattern,
    dtype_convert=False,
    swap_inputs=False,
):
    """Combine ``computation_call`` and an extra input with a binary op.

    ``swap_inputs`` flips the operand order; ``dtype_convert`` expects a
    convert-element-type node after the binary (used when an fp32 extra input
    is inplace-added to a bf16 output).
    """
    if swap_inputs:
        lhs, rhs = extra_input_pattern, computation_call
    else:
        lhs, rhs = computation_call, extra_input_pattern
    binary_pattern = CallFunction(binary_post_op, lhs, rhs)
    return _may_generate_pattern_with_dtype_convert(
        binary_pattern,
        KeywordArg("convert_dtype_after_inplace_add"),
        dtype_convert,
    )
251
+
252
+
253
def generate_pattern_with_unary(computation_call, unary_post_op):
    """Optionally append a unary post-op node; ``None`` means pass-through."""
    if unary_post_op is None:
        return computation_call
    return CallFunction(
        unary_post_op,
        computation_call,
    )
260
+
261
+
262
def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
    """Append a ``quantize_per_tensor`` node to ``computation_call``.

    ``with_dtype_convert`` expects a convert-element-type node (bf16 -> fp32)
    before the quantize.
    """
    maybe_converted = _may_generate_pattern_with_dtype_convert(
        computation_call,
        Arg(),
        with_dtype_convert,
    )
    return CallFunction(
        quantized_decomposed.quantize_per_tensor.default,
        maybe_converted,
        KeywordArg("o_inv_scale"),
        KeywordArg("o_zp"),
        KeywordArg("o_qmin"),
        KeywordArg("o_qmax"),
        KeywordArg("o_dtype"),
    )
277
+
278
+
279
+ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
280
+ if kwarg_name in check_node.kwargs:
281
+ actual_value = check_node.kwargs[kwarg_name]
282
+ return actual_value == expected_value
283
+ else:
284
+ assert len(check_node.args) >= (args_index + 1)
285
+ actual_value = check_node.args[args_index]
286
+ return actual_value == expected_value
287
+
288
+
289
def _is_valid_quantized_conv2d_optimization_pattern():
    """Extra-check factory for qconv2d unary matches.

    For fp32/bf16 outputs, verifies the qconv node recorded the same
    output_dtype (kwarg "output_dtype", positional index 13).
    """

    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype in [torch.float32, torch.bfloat16]:
            # Only keep matched pattern with same output_dtype
            qconv_node_after_weight_prepack = filter_nodes(
                match.nodes, torch.ops.onednn.qconv2d_pointwise
            )[0]
            return _check_node_kwarg_arg_value(
                qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
            )
        return True

    return fn
303
+
304
+
305
def _register_quantized_conv_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    """Register a lowering that rewrites a matched qconv2d(+unary[+quant])
    pattern into a single ``computation_op`` call with the unary post-op
    (``unary_attr``) folded in.
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv(match: Match, *args, **kwargs):
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # Conv Params
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
        # Output QParams; scale/zp only matter for int8 output
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
        assert (
            kwargs["attr"] == "none"
        )  # Expected no post op fused in weight prepack phase
        if unary_attr.op_name == "hardtanh":
            # hardtanh carries its clamp bounds as scalar attributes
            min_value = kwargs.get("min_value")
            max_value = kwargs.get("max_value")
            unary_attr.scalars_attr = [min_value, max_value]

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_unary_matcher_count"] += 1
        counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv
374
+
375
+
376
def _is_valid_quantized_linear_optimization_pattern():
    """Extra-check factory for qlinear unary matches.

    For fp32/bf16 outputs, verifies the qlinear node recorded the same
    output_dtype (kwarg "output_dtype", positional index 9).
    """

    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype in [torch.float32, torch.bfloat16]:
            # Only keep matched pattern with same output_dtype
            qlinear_node_after_weight_prepack = filter_nodes(
                match.nodes, torch.ops.onednn.qlinear_pointwise
            )[0]
            return _check_node_kwarg_arg_value(
                qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
            )
        return True

    return fn
390
+
391
+
392
def _register_quantized_linear_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    """Register a lowering that rewrites a matched qlinear(+unary[+quant])
    pattern into a single ``computation_op`` call with the unary post-op
    (``unary_attr``) folded in.
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_linear_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # bias (may be absent from the matched kwargs)
        b = kwargs["b"] if "b" in kwargs else None

        # Output QParams; scale/zp only matter for int8 output
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
        assert (
            kwargs["postop_name"] == "none"
        )  # Expected no post op fused in weight prepack phase

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_unary_matcher_count"] += 1
        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear
448
+
449
+
450
def _register_quantized_linear_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    """Register a lowering for matched qlinear + binary(+unary) patterns.

    Rewrites the match into a single binary ``computation_op`` call; the
    binary op ("sum" vs "add") and optional unary post-op come from
    ``binary_unary_attr``.
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qlinear_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Extra input: "accum" for in-place sum, "other" for out-of-place add
        x2 = (
            kwargs["accum"]
            if binary_unary_attr.binary_op_name == "sum"
            else kwargs["other"]
        )
        # No q-dq is inserted for the extra input, so identity qparams
        x2_scale = 1.0
        x2_zp = 0
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # bias (may be absent from the matched kwargs)
        b = kwargs["b"] if "b" in kwargs else None
        # Output QParams; scale/zp only matter for int8 output
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        x2.realize()
        from .mkldnn_fusion import _can_be_inplace

        binary_op_name = binary_unary_attr.binary_op_name

        if binary_op_name == "sum" and not _can_be_inplace(x2):
            # When we enable the GEMM Template, the output of QLinear
            # will be reshaped from 2D back to 3D if the input is 3D.
            # This causes _can_be_inplace(x2) to return False if x2 happens
            # to be the output of QLinear in this scenario.
            # Change the post op from sum to binary add for this case.
            # Refer to test case:
            # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
            binary_op_name = "add"

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            x2,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_binary_matcher_count"] += 1
        counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear_binary
529
+
530
+
531
def _is_valid_qconv_binary_optimization_pattern():
    """Extra-check factory for qconv2d binary (sum) fusion matches."""
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qconv2d_pointwise
    )
535
+
536
+
537
def _is_valid_qlinear_binary_optimization_pattern():
    """Extra-check factory for qlinear binary fusion matches."""
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qlinear_pointwise,
        # we don't insert q-dq for extra input due to accuracy issues
        extra_input_from_dequant=False,
    )
543
+
544
+
545
def _is_valid_quantized_op_binary_optimization_pattern(
    qop, extra_input_from_dequant=True
):
    """Extra-check factory validating binary fusion matches for ``qop``."""
    # Check if it's a valid Binary Pattern for qconv2d and qlinear:
    # * qop_pointwise should only has one users
    # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern
    # * the two inputs of binary node should have attribute "meta" and should be tensors
    # * the two inputs of binary node should have the same shape
    # * All users of the extra input in this pattern should be
    #   ancestor nodes of the compute node, except for the binary node
    #   connected to the compute node.
    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        compute_node = filter_nodes(match.nodes, qop)[0]
        # qop_pointwise should only have one user
        if len(compute_node.users) != 1:
            return False
        binary_node_inputs = next(iter(compute_node.users)).args
        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
        if output_dtype in [torch.float32, torch.bfloat16]:
            extra_input_of_binary_node = None
            for arg in binary_node_inputs:
                if arg != compute_node:
                    extra_input_of_binary_node = arg
                    break
            assert extra_input_of_binary_node is not None
            # Extra input of binary node comes from dequant pattern
            if extra_input_from_dequant and (
                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
                or (
                    extra_input_of_binary_node.target
                    != quantized_decomposed.dequantize_per_tensor.default
                )
            ):
                return False

        # the two inputs of binary node should have attribute "meta" and should be tensors
        if not (
            hasattr(binary_node_inputs[0], "meta")
            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ) or not (
            hasattr(binary_node_inputs[1], "meta")
            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ):
            return False
        # the two inputs of binary node should have the same shape
        if (
            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
        ):
            return False

        # All users of the extra input in this pattern should be
        # ancestor nodes of the compute node, except for the binary node
        # connected to the compute node.

        from .mkldnn_fusion import _get_remaining_users

        # Which kwarg carries the extra input depends on the matched pattern
        # (int8 output / no-dequant extra input use "accum"/"other").
        extra_input_of_pattern = (
            match.kwargs["other"]
            if "other" in match.kwargs
            else (
                match.kwargs["accum"]
                if output_dtype == torch.uint8 or (not extra_input_from_dequant)
                else match.kwargs["accum_after_dequant"]
            )
        )
        if (
            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
            or extra_input_of_pattern == compute_node.args[0]
        ):
            return False
        return True

    return fn
620
+
621
+
622
def _register_quantized_conv_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    """Register a lowering for matched qconv2d + binary(+unary) patterns.

    Rewrites the match into a single binary ``computation_op`` call, with
    the accumulator fused in-place (sum post-op).
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qconv_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
        # Accumulator: dequantized form for int8 output, raw otherwise
        accum = (
            kwargs["accum"]
            if output_dtype == torch.uint8
            else kwargs["accum_after_dequant"]
        )
        accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
        accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        # Output QParams; scale/zp only matter for int8 output
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        accum.realize()
        from .mkldnn_fusion import _can_be_inplace

        assert _can_be_inplace(
            accum
        ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."

        computation_args = (
            x,
            x_scale,
            x_zp,
            accum,
            accum_scale,
            accum_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            binary_unary_attr.binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_binary_matcher_count"] += 1
        counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv_binary
696
+
697
+
698
def _register_quantization_unary_fusion():
    """Register all qconv2d/qlinear unary post-op fusion patterns.

    Patterns with an int8 (quantized) output are registered at a lower
    pass number than their fp32/bf16 counterparts so that superset
    patterns (e.g. qconv -> relu -> quant) are matched before their
    subsets (qconv -> relu).
    """
    from .mkldnn_fusion import (
        _gelu_fusion_1 as _gelu_fusion_erf,
        _gelu_fusion_2 as _gelu_fusion_tanh,
        _hardswish_fusion,
        _hardtanh_fusion,
        _silu_fusion,
    )

    class UnaryAttr:
        def __init__(
            self, op_name: str, scalars_attr=None, algorithm_attr=None
        ) -> None:
            # op_name: oneDNN post-op name ("relu", "hardtanh", ...)
            self.op_name = op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
        # QConv2d
        # Priority 1 to match: QConv2d Unary pattern with int8 output
        # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
        is_bf16 = original_pattern_output_dtype == torch.bfloat16
        conv_unary_replace_patterns = {
            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                get_dequantize_qconv_pt2e_pattern(1),
            ),
            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
                ),
            ),
            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                1,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
        conv_unary_replace_float_out_patterns = {
            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
            ),
            UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # QLinear
        for x_scale_zp_are_tensors in (False, True):
            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
            # Priority 1 to match: QLinear Unary pattern with int8 output
            linear_unary_replace_patterns = {
                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                    qlinear_pattern,
                ),
                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
                ),
                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )

            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
            linear_unary_replace_float_out_patterns = {
                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                    qlinear_pattern, aten.relu.default
                ),
                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    2,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )
896
+
897
+ def _register_quantization_binary_fusion():
898
+ class BinaryUnaryAttr:
899
+ def __init__(
900
+ self,
901
+ binary_op_name: str,
902
+ alpha=None,
903
+ unary_op_name: str = "none",
904
+ scalars_attr=None,
905
+ algorithm_attr=None,
906
+ ) -> None:
907
+ self.binary_op_name = binary_op_name
908
+ self.alpha = alpha if alpha else 1.0
909
+ self.unary_op_name = unary_op_name
910
+ self.scalars_attr = scalars_attr if scalars_attr else []
911
+ self.algorithm_attr = algorithm_attr if algorithm_attr else ""
912
+
913
+ for int8_mixed_bf16_with_inplace_add in [False, True]:
914
+ # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
915
+ binary_replace_patterns = {
916
+ BinaryUnaryAttr(
917
+ "sum", 1.0, "none", [], ""
918
+ ): generate_pattern_with_output_quant(
919
+ generate_pattern_with_binary(
920
+ aten.add.Tensor,
921
+ get_dequantize_qconv_pt2e_pattern(1),
922
+ dequantize_accum_pattern,
923
+ int8_mixed_bf16_with_inplace_add,
924
+ ),
925
+ ),
926
+ BinaryUnaryAttr(
927
+ "sum", 1.0, "relu", [], ""
928
+ ): generate_pattern_with_output_quant(
929
+ generate_pattern_with_unary(
930
+ generate_pattern_with_binary(
931
+ aten.add.Tensor,
932
+ get_dequantize_qconv_pt2e_pattern(1),
933
+ dequantize_accum_pattern,
934
+ int8_mixed_bf16_with_inplace_add,
935
+ ),
936
+ aten.relu.default,
937
+ ),
938
+ ),
939
+ }
940
+
941
+ for binary_unary_attr, patterns in binary_replace_patterns.items():
942
+ _register_quantized_conv_binary_lowering(
943
+ patterns,
944
+ 0, # pass_number
945
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
946
+ binary_unary_attr, # binary_unary_attr
947
+ )
948
+
949
+ # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
950
+ binary_replace_float_out_patterns = {
951
+ BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
952
+ generate_pattern_with_binary(
953
+ aten.add.Tensor,
954
+ get_dequantize_qconv_pt2e_pattern(1),
955
+ KeywordArg("accum_after_dequant"),
956
+ int8_mixed_bf16_with_inplace_add,
957
+ ),
958
+ aten.relu.default,
959
+ ),
960
+ }
961
+
962
+ for (
963
+ binary_unary_attr,
964
+ patterns,
965
+ ) in binary_replace_float_out_patterns.items():
966
+ if int8_mixed_bf16_with_inplace_add:
967
+ _register_quantized_conv_binary_lowering(
968
+ patterns,
969
+ 0, # pass_number
970
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
971
+ binary_unary_attr, # binary_unary_attr
972
+ )
973
+ else:
974
+ _register_quantized_conv_binary_lowering(
975
+ patterns,
976
+ 1, # pass_number
977
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
978
+ binary_unary_attr, # binary_unary_attr
979
+ )
980
+
981
+ # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
982
+ binary_replace_float_out_patterns = {
983
+ BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary(
984
+ aten.add.Tensor,
985
+ get_dequantize_qconv_pt2e_pattern(1),
986
+ KeywordArg("accum_after_dequant"),
987
+ int8_mixed_bf16_with_inplace_add,
988
+ ),
989
+ }
990
+
991
+ for (
992
+ binary_unary_attr,
993
+ patterns,
994
+ ) in binary_replace_float_out_patterns.items():
995
+ _register_quantized_conv_binary_lowering(
996
+ patterns,
997
+ 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number
998
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
999
+ binary_unary_attr, # binary_unary_attr
1000
+ )
1001
+
1002
+ # QLinear
1003
+ r"""
1004
+ Supported linear-binary(-unary) patterns
1005
+
1006
+ linear(X) extra input
1007
+ \ /
1008
+ Add
1009
+ |
1010
+ Optional(relu)
1011
+ |
1012
+ Y
1013
+
1014
+ 1. int8-mixed-fp32
1015
+ +---+---------------+-----------+------------------------------+---------+
1016
+ | # | Add type | Quant out | Pattern | Post op |
1017
+ +---+---------------+-----------+------------------------------+---------+
1018
+ | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add |
1019
+ +---+---------------+-----------+------------------------------+---------+
1020
+ | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum |
1021
+ +---+---------------+-----------+------------------------------+---------+
1022
+
1023
+ 2. int8-mixed-bf16
1024
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1025
+ | # | X2 dtype | Add type | Quant out | Pattern | Post op |
1026
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1027
+ | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add |
1028
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1029
+ | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum |
1030
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1031
+ | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add |
1032
+ | | | In-place right| | | |
1033
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1034
+ | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum |
1035
+ | | | In-place right| | | |
1036
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1037
+ | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add |
1038
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1039
+ | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add |
1040
+ +---+----------+---------------+-----------+-----------------------------------------+---------+
1041
+
1042
+ Note
1043
+ (1) The positions of linear and the extra input can be swapped.
1044
+ (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the
1045
+ extra input, we don't match that pattern because we cannot match all these patterns in 3 passes.
1046
+ """
1047
+ for x_scale_zp_are_tensors in (False, True):
1048
+ qlinear_binary_op = (
1049
+ torch.ops.onednn.qlinear_pointwise.binary_tensor
1050
+ if x_scale_zp_are_tensors
1051
+ else torch.ops.onednn.qlinear_pointwise.binary
1052
+ )
1053
+ unary_postop_list = ["none", "relu"]
1054
+ unary_postop_dict = {
1055
+ "none": None,
1056
+ "relu": aten.relu.default,
1057
+ }
1058
+ convert_dtype_after_binary_list = [False, True]
1059
+
1060
+ # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
1061
+ # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
1062
+ # totally 3 patterns (2 are identical)
1063
+ swap_binary_inputs_list = [False, True]
1064
+ int8_mixed_bf16_list = [False, True]
1065
+ combinations = itertools.product(
1066
+ unary_postop_list,
1067
+ int8_mixed_bf16_list,
1068
+ swap_binary_inputs_list,
1069
+ convert_dtype_after_binary_list,
1070
+ )
1071
+ qlinear_binary_replace_patterns = {}
1072
+ for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
1073
+ if not int8_mixed_bf16 and cvt_dtype_binary:
1074
+ # No convert node after binary node if dtypes are all fp32
1075
+ continue
1076
+ qlinear_binary_replace_patterns.update(
1077
+ {
1078
+ BinaryUnaryAttr(
1079
+ "add", 1.0, unary_op, [], ""
1080
+ ): generate_pattern_with_output_quant(
1081
+ generate_pattern_with_unary(
1082
+ generate_pattern_with_binary(
1083
+ aten.add.Tensor,
1084
+ get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1085
+ KeywordArg("other"),
1086
+ # If fp32 extra input is inplace added to bf16 linear output,
1087
+ # a to_bf16 node is inserted after binary
1088
+ dtype_convert=cvt_dtype_binary,
1089
+ swap_inputs=swap_inputs,
1090
+ ),
1091
+ unary_postop_dict[unary_op],
1092
+ ),
1093
+ )
1094
+ }
1095
+ )
1096
+ for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
1097
+ _register_quantized_linear_binary_lowering(
1098
+ patterns,
1099
+ 0, # pass_number
1100
+ qlinear_binary_op, # computation_op
1101
+ binary_unary_attr, # binary_unary_attr
1102
+ )
1103
+
1104
+ # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
1105
+ # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
1106
+ # totally 2 patterns (2 are identical)
1107
+ binary_replace_float_out_patterns = {}
1108
+ for swap_binary_inputs in swap_binary_inputs_list:
1109
+ binary_replace_float_out_patterns.update(
1110
+ {
1111
+ BinaryUnaryAttr(
1112
+ "sum", 1.0, "relu", [], ""
1113
+ ): generate_pattern_with_unary(
1114
+ generate_pattern_with_binary(
1115
+ aten.add.Tensor,
1116
+ get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1117
+ KeywordArg("accum"),
1118
+ dtype_convert=False,
1119
+ swap_inputs=swap_binary_inputs,
1120
+ ),
1121
+ aten.relu.default,
1122
+ ),
1123
+ }
1124
+ )
1125
+ for (
1126
+ binary_unary_attr,
1127
+ patterns,
1128
+ ) in binary_replace_float_out_patterns.items():
1129
+ _register_quantized_linear_binary_lowering(
1130
+ patterns,
1131
+ 1, # pass_number
1132
+ qlinear_binary_op, # computation_op
1133
+ binary_unary_attr,
1134
+ )
1135
+ # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
1136
+ # Covers case (6) of int8-mixed-bf16
1137
+ binary_replace_float_out_patterns = {}
1138
+ for swap_binary_inputs in swap_binary_inputs_list:
1139
+ binary_replace_float_out_patterns.update(
1140
+ {
1141
+ BinaryUnaryAttr(
1142
+ "add", 1.0, "relu", [], ""
1143
+ ): generate_pattern_with_unary(
1144
+ generate_pattern_with_binary(
1145
+ aten.add.Tensor,
1146
+ get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1147
+ KeywordArg("other"),
1148
+ dtype_convert=True,
1149
+ swap_inputs=swap_binary_inputs,
1150
+ ),
1151
+ aten.relu.default,
1152
+ ),
1153
+ }
1154
+ )
1155
+ for (
1156
+ binary_unary_attr,
1157
+ patterns,
1158
+ ) in binary_replace_float_out_patterns.items():
1159
+ _register_quantized_linear_binary_lowering(
1160
+ patterns,
1161
+ 1, # pass_number
1162
+ qlinear_binary_op, # computation_op
1163
+ binary_unary_attr,
1164
+ )
1165
+
1166
+ # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
1167
+ # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
1168
+ # totally 2 patterns (2 are identical)
1169
+ binary_replace_float_out_patterns = {}
1170
+ for swap_binary_inputs in swap_binary_inputs_list:
1171
+ binary_replace_float_out_patterns.update(
1172
+ {
1173
+ BinaryUnaryAttr(
1174
+ "sum", 1.0, "none", [], ""
1175
+ ): generate_pattern_with_binary(
1176
+ aten.add.Tensor,
1177
+ get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1178
+ KeywordArg("accum"),
1179
+ dtype_convert=False,
1180
+ swap_inputs=swap_binary_inputs,
1181
+ ),
1182
+ }
1183
+ )
1184
+ for (
1185
+ binary_unary_attr,
1186
+ patterns,
1187
+ ) in binary_replace_float_out_patterns.items():
1188
+ _register_quantized_linear_binary_lowering(
1189
+ patterns,
1190
+ 2, # pass_number
1191
+ qlinear_binary_op, # computation_op
1192
+ binary_unary_attr,
1193
+ )
1194
+ # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
1195
+ # Covers (6) of int8-mixed-bf16
1196
+ binary_replace_float_out_patterns = {}
1197
+ for swap_binary_inputs in swap_binary_inputs_list:
1198
+ binary_replace_float_out_patterns.update(
1199
+ {
1200
+ BinaryUnaryAttr(
1201
+ "add", 1.0, "none", [], ""
1202
+ ): generate_pattern_with_binary(
1203
+ aten.add.Tensor,
1204
+ get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1205
+ KeywordArg("other"),
1206
+ dtype_convert=True,
1207
+ swap_inputs=swap_binary_inputs,
1208
+ ),
1209
+ }
1210
+ )
1211
+ for (
1212
+ binary_unary_attr,
1213
+ patterns,
1214
+ ) in binary_replace_float_out_patterns.items():
1215
+ _register_quantized_linear_binary_lowering(
1216
+ patterns,
1217
+ 2, # pass_number
1218
+ qlinear_binary_op, # computation_op
1219
+ binary_unary_attr,
1220
+ )
1221
+
1222
+
1223
def _is_valid_quantized_maxpool2d_optimization_pattern():
    """Build the extra-check predicate for the quantized maxpool2d pattern.

    The returned callable accepts a ``Match`` only when the matched getitem
    node extracts index 0, i.e. the pooled values output of
    ``max_pool2d_with_indices`` rather than the indices output.
    """

    def _check(match):
        getitem_nodes = filter_nodes(match.nodes, operator.getitem)
        # Index 1 would select the indices output, which we do not lower.
        return getitem_nodes[0].args[1] == 0

    return _check
1231
+
1232
+
1233
def _register_quantized_maxpool2d_lowering(
    pattern,
    computation_op,
):
    """Register a lowering that rewrites dequant->maxpool2d->quant to
    ``computation_op`` operating directly on the quantized tensor."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
    )
    def qmaxpool2d(match: Match, *args, **kwargs):
        x = kwargs["x"]
        kernel_size = kwargs["kernel_size"]
        # Optional maxpool args may be absent from the traced graph (Dynamo
        # drops defaults); fall back to the aten defaults here.
        stride = kwargs.get("stride")
        padding = kwargs.get("padding", 0)
        dilation = kwargs.get("dilation", 1)
        ceil_mode = kwargs.get("ceil_mode", False)

        # Normalize scalar defaults to per-dimension lists before padding.
        if padding == 0:
            padding = [0, 0]
        if dilation == 1:
            dilation = [1, 1]
        if not stride:
            stride = kernel_size
        kernel_size, stride, padding, dilation = (
            pad_listlike(v, 2) for v in (kernel_size, stride, padding, dilation)
        )

        for v in (kernel_size, stride, padding, dilation):
            assert len(v) == 2

        computation_args = (x, kernel_size, stride, padding, dilation, ceil_mode)
        computation_args, _ = require_channels_last(computation_op, *computation_args)
        counters["inductor"]["qmaxpool2d_matcher_count"] += 1
        counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qmaxpool2d
1279
+
1280
+
1281
def _register_quantization_maxpool2d():
    """Register quantized maxpool2d lowerings for both the plain
    ``max_pool2d_with_indices`` form and the low-memory prims variant,
    once per possible keyword-argument arity of the traced graph.
    """
    # Currently, the default parameters are not in FX Graph generated by Dynamo export.
    # So, if user defines nn.MaxPool2d with different assignment of default parameter,
    # it will generate graph with different number of input nodes and hence
    # different pattern to be matched.
    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
    max_pool2d_args_list = [
        [
            KeywordArg("stride"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
            KeywordArg("ceil_mode"),
        ],
    ]
    for max_pool2d_args in max_pool2d_args_list:
        # maxpool over a dequantized activation (aten overload).
        dequantize_maxpool2d_pattern = CallFunction(
            aten.max_pool2d_with_indices.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
        )
        # Same, but the low-memory decomposition with an extra offset dtype arg.
        dequantize_lowmem_maxpool2d_pattern = CallFunction(
            prims._low_memory_max_pool2d_with_offsets.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
            KeywordArg("offset_dtype"),
        )
        # Both ops return a tuple; only the getitem projection is matched
        # (the extra check restricts it to index 0, the values).
        dequantize_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_maxpool2d_pattern,
            Arg(),
        )
        dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_lowmem_maxpool2d_pattern,
            Arg(),
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
            quantized.max_pool2d.default,
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(
                dequantize_lowmem_maxpool2d_get_item_pattern
            ),
            quantized.max_pool2d.default,
        )
1341
+
1342
+
1343
def _is_input_output_same_scale_zp(check_node):
    """Build a predicate that accepts a match only when every dequantized
    input and the quantized output share one scale and one zero point, so the
    op can run directly on the int8 tensors."""

    def _same_scale_zp(match):
        input_dq_nodes = filter_nodes(
            match.nodes, quantized_decomposed.dequantize_per_tensor.default
        )
        output_q_nodes = filter_nodes(
            match.nodes, quantized_decomposed.quantize_per_tensor.default
        )
        assert len(output_q_nodes) == 1, "expect only 1 add node at output quant pattern"

        # Zero points (args[2] of both quant and dequant) must be identical.
        zps = [n.args[2] for n in input_dq_nodes] + [output_q_nodes[0].args[2]]
        if any(zp != zps[0] for zp in zps):
            return False

        # Scales (args[1]) must agree to within a small relative tolerance.
        scales = [n.args[1] for n in input_dq_nodes] + [output_q_nodes[0].args[1]]
        return all(math.isclose(s, scales[0], rel_tol=1e-5) for s in scales)  # type: ignore[arg-type]

    return _same_scale_zp
1370
+
1371
+
1372
def _register_quantized_cat_lowering(
    pattern,
    computation_op,
):
    """Register a lowering of dequant->cat->quant to ``computation_op`` on the
    raw quantized inputs (valid because the extra check guarantees a shared
    scale/zero-point across inputs and output)."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
    )
    def qcat(match: Match, inputs, dim, **kwargs):
        # Each entry of `inputs` is [x_i, x_i_dq_dtype, x_i_zp, x_i_scale];
        # only the quantized tensor itself is concatenated.
        uint8_inputs = list(map(operator.itemgetter(0), inputs))
        counters["inductor"]["qcat_matcher_count"] += 1
        counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
        return L[computation_op](uint8_inputs, dim)

    return qcat
1388
+
1389
+
1390
# Bare dequantize_per_tensor pattern with every call argument left as a
# wildcard Arg().  Based on how matches are inspected elsewhere in this file,
# args[1] is the scale and args[2] the zero point; the remaining wildcards are
# presumably qmin/qmax/dtype -- TODO confirm against the op's schema.  Used
# where qparams are validated by a separate extra check (e.g. quantized cat).
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
)
1399
+
1400
+
1401
def _register_quantization_cat():
    """Register the quantized-cat lowering: cat over dequantized inputs
    followed by an output quant collapses to aten.cat on the int8 tensors."""
    _register_quantized_cat_lowering(
        generate_pattern_with_output_quant(
            CallFunction(
                aten.cat.default,
                ListOf(_raw_dequantize_per_tensor_activation_pattern),
                KeywordArg("dim"),
            )
        ),
        aten.cat,
    )
1411
+
1412
+
1413
def _register_quantized_reshape_lowering(
    pattern,
    computation_op,
):
    """Register a lowering of dequant->reshape->quant to ``computation_op``
    applied straight to the quantized tensor."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
    )
    def qreshape(match: Match, *args, **kwargs):
        counters["inductor"]["qreshape_matcher_count"] += 1
        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
        # Reshape leaves values untouched, so the quantized tensor can be
        # reshaped directly (same scale/zp enforced by the extra check).
        return L[computation_op](kwargs["x"], kwargs["shape"])

    return qreshape
1429
+
1430
+
1431
def _register_quantization_reshape():
    """Register the quantized-reshape lowering: dequant->reshape->quant
    collapses to a reshape on the int8 tensor."""
    _register_quantized_reshape_lowering(
        generate_pattern_with_output_quant(
            CallFunction(
                torch.ops.aten.reshape.default,
                get_dequantize_per_tensor_activation_pattern(),
                KeywordArg("shape"),
            )
        ),
        aten.reshape,
    )
1441
+
1442
+
1443
def _is_valid_woq_optimization_pattern():
    """Build the extra check for the weight-only-quantization (WOQ) mm
    patterns: bf16 activation, int8 weight, bf16 scales, all on CPU."""

    def _check(match):
        assert all(k in match.kwargs for k in ("x", "weight", "scales"))
        x, weight, scales = (
            match.kwargs[name].meta["val"] for name in ("x", "weight", "scales")
        )
        # For now, we only support woq mm kernels
        # with x.type=bfloat16 and w.type=int8
        if x.dtype != torch.bfloat16 or weight.dtype != torch.int8:
            return False
        if scales.dtype != torch.bfloat16:
            return False
        # _weight_int8pack_mm kernel only supports cpu now
        # TODO: add cuda kernel support instead of calling mul+sum
        return (
            x.device.type == "cpu"
            and x.device == weight.device
            and x.device == scales.device
        )

    return _check
1463
+
1464
+
1465
def _register_woq_lowering(pattern, computation_woq, computation_reshape):
    """Register a lowering of a WOQ matmul pattern to
    reshape -> ``computation_woq`` -> reshape."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_woq_optimization_pattern(),
    )
    def woq(match: Match, *args, **kwargs):
        x, weight, scales = kwargs["x"], kwargs["weight"], kwargs["scales"]
        counters["inductor"]["woq_matcher_count"] += 1
        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
        # Flatten leading dims so the kernel sees a 2D matmul, then restore
        # the original leading dims with out_features appended.
        original_shape = x.get_size()
        out_features = weight.get_size()[0]
        flattened = L[computation_reshape](x, [-1, original_shape[-1]])
        mm_out = L[computation_woq](flattened, weight, scales)
        return L[computation_reshape](mm_out, original_shape[:-1] + [out_features])

    return woq
1487
+
1488
+
1489
def _register_woq_mm_int8_pattern1():
    """Register the WOQ pattern for the mm dispatch with an activation
    reshape: mul(reshape(mm(reshape(x), permute(to_dtype(weight)))), scales).
    """
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, with x reshape
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1512
+
1513
+
1514
def _register_woq_mm_int8_pattern2():
    """Register the WOQ pattern for the mm dispatch without an activation
    reshape: mul(reshape(mm(x, permute(to_dtype(weight)))), scales).
    """
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, w/o x reshape
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                KeywordArg("x"),
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1537
+
1538
+
1539
def _register_woq_mm_int8_pattern3():
    """Register the WOQ pattern for the bmm dispatch:
    mul(bmm(expand(x), expand(permute(to_dtype(weight)))), scales).
    """
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to bmm
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.bmm.default,
            CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
            CallFunction(
                aten.expand.default,
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
                Arg(),
            ),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1562
+
1563
+
1564
def _register_quantization_lowerings():
    """Register every int8 quantization lowering in this file: unary and
    binary fusions, maxpool2d, cat, and reshape."""
    _register_quantization_unary_fusion()
    _register_quantization_binary_fusion()
    _register_quantization_maxpool2d()
    _register_quantization_cat()
    _register_quantization_reshape()
1570
+
1571
+
1572
def _register_woq_lowerings():
    """Register all weight-only-quantization (int8 weight) matmul patterns
    (mm with/without activation reshape, and bmm)."""
    _register_woq_mm_int8_pattern1()
    _register_woq_mm_int8_pattern2()
    _register_woq_mm_int8_pattern3()
1576
+
1577
+
1578
def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
    """Build the extra check for the dequant-promotion pass.

    The matched pattern ends in one of: dequantize_per_tensor,
    convert_element_type (to_bf16), or reshape.  Promotion only applies when
    the end node has more than one user, i.e. the dequant result is shared.
    """

    def _inner(match):
        assert dtype in [torch.float32, torch.bfloat16]
        dequant_pattern_end_node = match.output_node()
        if dequant_pattern_end_node.target not in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]:
            return False

        # Walk back from the pattern's end node to the dequant node; the
        # number of hops depends on whether a reshape and/or a to_bf16
        # conversion sits in between.
        if dequant_pattern_end_node.target is aten.reshape.default:
            dequant_node = (
                dequant_pattern_end_node.args[
                    0
                ]  # pattern: linear <- reshape <- dequant
                if dtype == torch.float32
                else dequant_pattern_end_node.args[0].args[
                    0
                ]  # pattern: linear <- reshape <- to_bf16 <- dequant
            )
        else:
            dequant_node = (
                dequant_pattern_end_node  # pattern: linear <- dequant
                if dtype == torch.float32
                else dequant_pattern_end_node.args[
                    0
                ]  # pattern: linear <- to_bf16 <- dequant
            )

        if (
            dequant_node.target
            in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]
            and len(list(dequant_pattern_end_node.users)) > 1
        ):
            # If dequant pattern has more than 1 users, then do dequant promoted
            return True
        return False

    return _inner
1622
+
1623
+
1624
def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
    """Register a freezing-graph pass that duplicates a shared dequant
    pattern once per user so each user can later match an int8 fusion."""

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_promotion_pattern(dtype),
        pass_number=pass_number,
    )
    def dequant_promotion(match: Match, *args, **kwargs):
        # Dequant_promotion will transform
        # graph 1:
        #             quant
        #       + - - - | - - - +
        #       |    dequant    |
        #       |    /     \    |
        #       | node1   node2 |
        #       + - | - - - | - +
        #         quant   quant
        # into:
        # graph 2:
        #             quant
        #       + - - / - \ - - +
        #       |dequant dequant|
        #       |   |       |   |
        #       | node1   node2 |
        #       + - | - - - | - +
        #         quant   quant
        # In graph 1, the dequant node is shared by node1 and node2,
        # as a result, neither node1 nor node2 could form an int8
        # fusion pattern.
        # After this transformation, the graph 2 could hit the int8
        # fusion pattern: dequant-node-quant, respectively for
        # node1 and node2.
        assert dtype in [torch.float32, torch.bfloat16]

        def clone_to_new_node(graph, source_node, user_node):
            # Clone the source_node to a new node
            # Replace user_node's input from source_node to new_node
            assert (
                source_node.op == "call_function"
            ), "clone_to_new_node only support node.op call_function"
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        # Find the start node and end node of a dequant pattern
        # * End node should be the match.output_node()
        # * Start node should be the node of dequantize_per_tensor
        dequant_pattern_end_node = match.output_node()
        assert dequant_pattern_end_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]

        # For a dequant pattern, we should expect see the node list as:
        # * OPT(aten.reshape.default)
        # * OPT(prims.convert_element_type.default) (to_bf16)
        # * dequantize_per_tensor
        def _find_first_node_in_dequant_pattern(_node):
            if _node.target in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]:
                # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
                return _node
            else:
                assert (
                    len(_node.args) >= 1
                ), "In in dequant pattern, each node should have more than 1 arg."
                return _find_first_node_in_dequant_pattern(_node.args[0])

        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
            dequant_pattern_end_node
        )

        assert dequant_pattern_start_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
        ]

        # Clone the dequant pattern for each user node
        # The first user keeps the original chain; each additional user gets
        # its own fresh copy of every node from the end node back to (but not
        # including) the quantized input of the start node.
        graph = match.graph
        user_node_list = list(dequant_pattern_end_node.users)
        for user_node in user_node_list[1:]:
            _source_node = dequant_pattern_end_node
            _user_node = user_node
            while _source_node != dequant_pattern_start_node.args[0]:
                _user_node = clone_to_new_node(graph, _source_node, _user_node)
                _source_node = _source_node.args[0]  # type: ignore[assignment]

        counters["inductor"]["dequant_promotion_matcher_count"] += 1
        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
1723
+
1724
def _is_valid_dequant_conv2d_pattern(dtype):
    """Build the extra check for the qconv weight-prepack pass: accept only
    CPU conv2d (4D input and weight) whose dequant input has a single user."""

    def _inner(match):
        # Here we do some further check to ensure:
        # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
        # 2. The dequant pattern has only 1 user of conv2d node.
        # If these conditions don't meet, we will not
        # insert weight prepack node into the matched pattern.
        conv_node = match.output_node()
        assert conv_node.target is aten.convolution.default
        input_meta_value = conv_node.args[0].meta.get("val")
        weight_meta_value = conv_node.args[1].meta.get("val")
        for meta_value in [input_meta_value, weight_meta_value]:
            if (
                meta_value is None
                or meta_value.device.type != "cpu"
                or meta_value.dim() != 4
            ):
                # Only support conv2d now
                return False

        assert dtype in [torch.float32, torch.bfloat16]

        # For bf16, a convert_element_type node sits between conv and dequant.
        if dtype == torch.float32:
            dequant_node = conv_node.args[0]
        else:
            convert_to_bf16 = conv_node.args[0]
            dequant_node = convert_to_bf16.args[0]

        if len(list(dequant_node.users)) != 1:
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False
        return True

    return _inner
1759
+
1760
+
1761
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
    """Register a freezing-graph pass that rewrites a dequant-conv2d pattern
    into onednn.qconv2d_pointwise fed by an onednn.qconv_prepack node, then
    erases the replaced dequant/conv nodes."""

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_conv2d_pattern(dtype),
        pass_number=pass_number,
    )
    def qconv_weight_prepack(match: Match, *args, **kwargs):
        """
        Match the pattern:
        int8 activation
          |
        dequant_per_tensor
          |
        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight

        Insert weight prepack node and change the pattern to:
        int8 activation
          |
        onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
        """
        assert dtype in [torch.float32, torch.bfloat16]
        conv_node = match.output_node()
        assert conv_node.target is aten.convolution.default
        # Walk from the conv's activation input back to the dequant node;
        # bf16 adds an intermediate convert_element_type node.
        if dtype == torch.float32:
            dequant_node = conv_node.args[0]
        else:
            convert_to_bf16 = conv_node.args[0]
            dequant_node = convert_to_bf16.args[0]  # type: ignore[union-attr]
        # The weight branch may contain a clone (to channels-last) node.
        has_clone_to_channel_last_node_in_pattern = (
            conv_node.args[1].target is aten.clone.default  # type: ignore[union-attr]
        )
        clone_node = (
            conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
        )

        if dtype == torch.float32:
            dequant_per_channel = (
                clone_node.args[0]  # type: ignore[union-attr]
                if has_clone_to_channel_last_node_in_pattern
                else conv_node.args[1]
            )
        else:
            weight_to_bf16_node = (
                clone_node.args[0]  # type: ignore[union-attr]
                if has_clone_to_channel_last_node_in_pattern
                else conv_node.args[1]
            )
            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]

        assert (
            dequant_per_channel.target  # type: ignore[union-attr]
            is quantized_decomposed.dequantize_per_channel.default
        )

        # Activation QParams
        qx, x_zp, x_scale = (
            kwargs["x"],
            kwargs["x_zp"],
            kwargs["x_scale"],
        )

        # Weight QParams
        qw, w_scale, w_zp = (
            kwargs["q_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # Conv Params
        bias, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )

        x_shape = qx.meta.get("tensor_meta").shape
        if has_free_symbols(x_shape):
            # For dynamic shape case, we can't get activation shape ahead of runtime.
            x_shape = None
        graph = match.graph
        with graph.inserting_before(conv_node):
            # Insert weight prepack node and the QConv node
            packed_weight_inputs = (
                qw,
                w_scale,
                x_scale,
                x_zp,
                stride,
                padding,
                dilation,
                groups,
                x_shape,
            )
            packed_weight_op = torch.ops.onednn.qconv_prepack
            prepack_weight_node = graph.call_function(
                packed_weight_op, args=packed_weight_inputs
            )

            new_args: Tuple[Any, ...] = (
                qx,
                x_scale,
                x_zp,
                prepack_weight_node,
                w_scale,
                w_zp,
                bias,
                stride,
                padding,
                dilation,
                groups,
                1.0,  # output_scale
                0,  # output_zero_point
                dtype,  # output_dtype
                "none",  # attr
                [],  # scalars
                "",  # algorithm
            )
            new_conv_node = graph.call_function(
                torch.ops.onednn.qconv2d_pointwise.default, args=new_args
            )
            conv_node.replace_all_uses_with(new_conv_node)
            new_conv_node.meta.update(conv_node.meta)

        # Erase the original conv node
        graph.erase_node(conv_node)
        # Erase the dequant pattern
        if dtype == torch.bfloat16:
            graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined, arg-type]
        graph.erase_node(dequant_node)  # type: ignore[arg-type]
        # Erase the dequant per channel pattern
        if clone_node is not None:
            graph.erase_node(clone_node)  # type: ignore[arg-type]
        if dtype == torch.bfloat16:
            graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
        counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
        counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
            match.nodes
        )
1902
+
1903
+
1904
def _generate_dequant_convolution_node_pattern(
    _dequant_per_channel_pattern, dtype=torch.float32
):
    """Build the aten.convolution pattern over a dequantized activation (with
    an optional to-bf16 conversion) and the given per-channel weight pattern.
    """
    assert dtype in [torch.float32, torch.bfloat16]
    dequant_convolution_node_pattern = CallFunction(
        aten.convolution.default,
        _may_generate_pattern_with_dtype_convert(
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("autocast_act_dtype"),
            dtype == torch.bfloat16,
        ),
        _dequant_per_channel_pattern,
        KeywordArg("b"),
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("is_transposed"),
        KeywordArg("out_padding"),
        KeywordArg("groups"),
    )
    return dequant_convolution_node_pattern
1925
+
1926
+
1927
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
    """Return the pair of conv weight-prepack patterns for ``dtype``: one
    without and one with a clone (channels-last) node on the weight branch."""
    assert dtype in [torch.float32, torch.bfloat16]
    return (
        _generate_dequant_convolution_node_pattern(
            dequantize_per_channel_weight_pattern
            if dtype == torch.float32
            else dequantize_per_channel_to_bf16_weight_pattern,
            dtype,
        ),
        # There is another pattern due to the pass of convert_conv_weights_to_channels_last
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
        # Depend on some heuristics, it may or may not insert to(channel_last) node
        # between convolution and dequant_per_channel node
        _generate_dequant_convolution_node_pattern(
            dequantize_per_channel_clone_weight_pattern
            if dtype == torch.float32
            else dequantize_per_channel_to_bf16_clone_weight_pattern,
            dtype,
        ),
    )
1947
+
1948
+
1949
def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
    """Locate the mm/addmm/bmm node of a matched linear pattern.

    Returns ``(linear_node, output_reshape_node)``; ``output_reshape_node`` is
    only non-None for the contiguous >2D case, where the match ends in a
    reshape of the linear output.
    """
    output_reshape_node = None
    if not input_dim_exceeds_two:
        # 2D input: the linear op is the match's output node itself.
        linear_node = match.output_node()
    elif input_contiguous:
        # >2D contiguous input: the match ends in reshape(linear(...)).
        output_reshape_node = match.output_node()
        assert output_reshape_node.target is aten.reshape.default
        linear_node = output_reshape_node.args[0]
    else:
        # >2D non-contiguous input: linear was decomposed into a single bmm.
        bmm_nodes = filter_nodes(match.nodes, aten.bmm.default)
        assert len(bmm_nodes) == 1
        linear_node = bmm_nodes[0]

    assert linear_node.target in (
        aten.addmm.default,
        aten.mm.default,
        aten.bmm.default,
    )
    return linear_node, output_reshape_node
1969
+
1970
+
1971
def _get_linear_dq_node(
    linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
):
    """Walk from the linear node's activation input back to its dequant node.

    Returns ``(dequant_node, act_reshape_node, activation_to_bf16_node,
    act_expand_node)``; the intermediate nodes are None when absent from the
    matched shape of the pattern (they depend on dtype, input rank, and
    contiguity).
    """
    act_reshape_node = None
    activation_to_bf16_node = None
    act_expand_node = None
    if input_dim_exceeds_two:
        if input_contiguous:
            act_reshape_node = linear_node.args[input_index]
            assert act_reshape_node.target is aten.reshape.default
            if dtype == torch.float32:
                # pattern: linear -> reshape -> dequant
                dequant_node = act_reshape_node.args[0]
            else:
                # pattern: linear -> reshape -> to_bf16 -> dequant
                activation_to_bf16_node = act_reshape_node.args[0]
                dequant_node = activation_to_bf16_node.args[0]
        else:
            # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
            act_expand_node = linear_node.args[input_index]
            assert act_expand_node.target is aten.expand.default
            if dtype == torch.float32:
                dequant_node = act_expand_node.args[0]
            else:
                activation_to_bf16_node = act_expand_node.args[0]
                dequant_node = activation_to_bf16_node.args[0]
    else:
        if dtype == torch.float32:
            # pattern: linear -> dequant
            dequant_node = linear_node.args[input_index]
        else:
            # pattern: linear -> to_bf16 -> dequant
            activation_to_bf16_node = linear_node.args[input_index]
            dequant_node = activation_to_bf16_node.args[0]
    return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
2006
+
2007
+
2008
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
    """Build the extra check for the qlinear weight-prepack pass.

    Accepts a match only when the activation's dequant node has exactly one
    user, and — for the decomposed bmm case — when the expand/permute sizes
    recorded in the match kwargs are consistent with the actual activation
    and weight shapes.
    """

    def _inner(match):
        # Check dequant pattern has only 1 user.
        (
            linear_node,
            _,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)

        # addmm carries bias at args[0], so its activation input shifts to 1.
        input_index = 1 if linear_node.target is aten.addmm.default else 0
        assert dtype in [torch.float32, torch.bfloat16]
        (
            dequant_node,
            _,
            _,
            _,
        ) = _get_linear_dq_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )

        assert dequant_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
        ]

        if len(list(dequant_node.users)) != 1:
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False

        # Extra check for bmm pattern
        if input_dim_exceeds_two and not input_contiguous:
            # Check for act
            # Act expand size should be exactly same as act size
            act_expand_size = match.kwargs["act_expand_size"]
            act_node = match.kwargs["x"]
            if not (
                hasattr(act_node, "meta")
                and isinstance(act_node.meta.get("val", None), torch.Tensor)
                and (act_node.meta["val"].size() == torch.Size(act_expand_size))
            ):
                return False

            # Check for wgt
            # wgt permute dims should be [1, 0]
            wgt_permute_dims = match.kwargs["permute_axes"]
            if wgt_permute_dims != [1, 0]:
                return False

            # Check below wgt size items:
            # wgt before expand should with dim 2
            # Expand size should with dim 3
            # Expand size[0] should same as act size[0]
            # Expand size[1] should same as wgt size[1]
            # Expand size[2] should same as wgt size[0]
            qweight_node = match.kwargs["q_weight"]
            wgt_expand_size = match.kwargs["wgt_expand_size"]
            if not (
                hasattr(qweight_node, "meta")
                and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
                and len(qweight_node.meta["val"].size()) == 2
                and len(wgt_expand_size) == 3
                and wgt_expand_size[0] == act_node.meta["val"].size()[0]
                and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
                and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
            ):
                return False

        return True

    return _inner
2078
+
2079
+
2080
def _register_qlinear_weight_prepack_pass(
    pattern,
    pass_number,
    dtype=torch.float32,
    input_dim_exceeds_two=False,
    input_contiguous=True,
):
    """Register a freezing-graph pattern that rewrites a dequant->(add)mm/bmm
    subgraph into onednn.qlinear_prepack + onednn.qlinear_pointwise.

    Args:
        pattern: pattern-matcher expression to match.
        pass_number: matcher pass ordering (dequant promotion runs at pass 0).
        dtype: activation compute dtype; must be float32 or bfloat16.
        input_dim_exceeds_two: activation rank > 2 (reshape/expand wrappers present).
        input_contiguous: whether the >2D activation was contiguous (contiguous
            case goes through reshape+mm; non-contiguous through expand+bmm).
    """

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_linear_pattern(
            dtype, input_dim_exceeds_two, input_contiguous
        ),
        pass_number=pass_number,
    )
    def qlinear_weight_prepack(match: Match, *args, **kwargs):
        """
        Match the pattern:
        int8 activation
          |
        dequant_per_tensor
          |
        mm/addmm <- t <- dequant_per_channel <- int8_weight

        Insert weight prepack node and change the pattern to:
        int8 activation
          |
        onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
        """
        assert dtype in [torch.float32, torch.bfloat16]
        (
            linear_node,
            output_reshape_node,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
        # addmm takes bias as args[0], so the activation shifts to args[1].
        input_index = 1 if linear_node.target is aten.addmm.default else 0
        weight_index = input_index + 1

        (
            dequant_node,
            act_reshape_node,
            activation_to_bf16_node,
            act_expand_node,
        ) = _get_linear_dq_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )

        # Non-contiguous >2D case: weight reaches bmm through an expand node.
        if input_dim_exceeds_two and not input_contiguous:
            wgt_expand_node = linear_node.args[weight_index]
            assert wgt_expand_node.target is aten.expand.default
            t_node = wgt_expand_node.args[0]
        else:
            t_node = linear_node.args[weight_index]

        # Walk back through the optional weight to(bf16) cast to the
        # per-channel dequantize node.
        if dtype == torch.float32:
            dequant_per_channel = t_node.args[0]
        else:
            weight_to_bf16_node = t_node.args[0]
            dequant_per_channel = weight_to_bf16_node.args[0]
        assert (
            dequant_per_channel.target
            is quantized_decomposed.dequantize_per_channel.default
        )

        # Activation QParams
        qx, x_zp, x_scale = (
            kwargs["x"],
            kwargs["x_zp"],
            kwargs["x_scale"],
        )

        # Weight QParams
        qw, w_scale, w_zp = (
            kwargs["q_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # Params
        bias = kwargs["b"] if "b" in kwargs else None

        x_shape = qx.meta.get("tensor_meta").shape
        if has_free_symbols(x_shape):
            # For dynamic shape case, we can't get activation shape ahead of runtime.
            x_shape = None
        graph = match.graph
        with graph.inserting_before(linear_node):
            # Insert weight prepack node and the qlinear node
            packed_weight_inputs = (
                qw,
                x_shape,
            )
            packed_weight_op = torch.ops.onednn.qlinear_prepack
            prepack_weight_node = graph.call_function(
                packed_weight_op, args=packed_weight_inputs
            )

            new_args: Tuple[Any, ...] = (
                qx,
                x_scale,
                x_zp,
                prepack_weight_node,
                w_scale,
                w_zp,
                bias,
                1.0,  # output_scale
                0,  # output_zero_point
                dtype,  # output_dtype
                "none",  # post op name
                [],  # post op args
                "",  # post op algorithm
            )
            Node = torch.fx.node.Node
            # Scale/zp given as graph nodes (dynamic qparams) selects the
            # .tensor overload; python scalars select .default.
            if isinstance(x_scale, Node) and isinstance(x_zp, Node):
                new_linear_node = graph.call_function(
                    torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
                )
            else:
                new_linear_node = graph.call_function(
                    torch.ops.onednn.qlinear_pointwise.default, args=new_args
                )
            # Redirect users of the pattern's output node to the new qlinear.
            if input_dim_exceeds_two:
                if input_contiguous:
                    output_reshape_node.replace_all_uses_with(new_linear_node)
                    new_linear_node.meta.update(output_reshape_node.meta)
                else:
                    if bias:
                        output_add_node_for_bias = match.output_node()
                        assert output_add_node_for_bias.target is aten.add.Tensor
                        output_add_node_for_bias.replace_all_uses_with(new_linear_node)
                        new_linear_node.meta.update(output_add_node_for_bias.meta)
                    else:
                        linear_node.replace_all_uses_with(new_linear_node)
                        new_linear_node.meta.update(linear_node.meta)
            else:
                linear_node.replace_all_uses_with(new_linear_node)
                new_linear_node.meta.update(linear_node.meta)

            # Erase the original linear node
            if input_dim_exceeds_two:
                if input_contiguous:
                    graph.erase_node(output_reshape_node)
                elif not input_contiguous and bias:
                    graph.erase_node(output_add_node_for_bias)  # type: ignore[possibly-undefined]
            graph.erase_node(linear_node)
            if input_dim_exceeds_two:
                if input_contiguous:
                    graph.erase_node(act_reshape_node)
                else:
                    graph.erase_node(act_expand_node)
                    graph.erase_node(wgt_expand_node)  # type: ignore[possibly-undefined]
            if dtype == torch.bfloat16:
                graph.erase_node(activation_to_bf16_node)
            # Erase the dequant pattern
            graph.erase_node(dequant_node)
            # Erase the dequant per channel pattern
            graph.erase_node(t_node)
            if dtype == torch.bfloat16:
                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
            graph.erase_node(dequant_per_channel)

        counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
        counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
            match.nodes
        )
2243
+
2244
+
2245
def _generate_dequant_linear_node_pattern(
    _dequant_per_channel_pattern,
    dtype=torch.float32,
    input_dim_exceeds_two=False,
    is_tensor_overload=False,
):
    """Build the (addmm, mm) dequant-linear pattern pair.

    Both patterns share the same activation sub-pattern:
    dequant_per_tensor -> optional to(bf16) -> optional reshape; the weight
    side is the transposed per-channel-dequant pattern. The addmm variant
    additionally captures the bias as KeywordArg("b").

    Args:
        _dequant_per_channel_pattern: weight-side per-channel dequant pattern.
        dtype: float32 or bfloat16; bf16 inserts the to(bf16) cast node.
        input_dim_exceeds_two: wrap activation and output in reshape nodes.
        is_tensor_overload: match the .tensor overload of dequantize_per_tensor.

    Returns:
        (dequant_linear_bias_pattern, dequant_linear_no_bias_pattern)
    """
    assert dtype in [torch.float32, torch.bfloat16]
    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)

    def _act_pattern():
        # Shared activation sub-pattern (was duplicated verbatim in both the
        # bias and no-bias patterns).
        return _may_generate_pattern_with_reshape(
            _may_generate_pattern_with_dtype_convert(
                get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
                KeywordArg("autocast_act_dtype"),
                dtype == torch.bfloat16,
            ),
            KeywordArg("act_reshape_size"),
            input_dim_exceeds_two,
        )

    dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.addmm.default,
            KeywordArg("b"),
            _act_pattern(),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.mm.default,
            _act_pattern(),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
2289
+
2290
+
2291
def _generate_dequant_bmm_node_pattern(
    _dequant_per_channel_pattern,
    dtype=torch.float32,
    with_bias=False,
    is_tensor_overload=False,
):
    """Build the bmm-shaped dequant-linear pattern.

    Used when the linear activation has rank > 2 and is not contiguous, in
    which case linear decomposes into expand(act) @ expand(permute(weight))
    via aten.bmm, optionally followed by an add for the bias.
    """
    # Weight side: per-channel dequant -> optional to(bf16) -> permute.
    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)

    assert dtype in [torch.float32, torch.bfloat16]
    activation_pattern = _may_generate_pattern_with_dtype_convert(
        get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
        KeywordArg("autocast_act_dtype"),
        dtype == torch.bfloat16,
    )
    dequant_bmm_pattern = CallFunction(
        aten.bmm.default,
        CallFunction(
            aten.expand.default,
            activation_pattern,
            KeywordArg("act_expand_size"),
        ),
        CallFunction(
            aten.expand.default,
            t_pattern,
            KeywordArg("wgt_expand_size"),
        ),
    )

    if with_bias:
        # Bias appears as a trailing add on the bmm output.
        return CallFunction(
            aten.add.Tensor,
            dequant_bmm_pattern,
            KeywordArg("b"),
        )
    return dequant_bmm_pattern
2330
+
2331
+
2332
def _generate_qlinear_weight_prepack_patterns(
    dtype=torch.float32,
    input_dim_exceeds_two=False,
    input_contiguous=True,
    with_bias=False,
    is_tensor_overload=False,
):
    """Dispatch to the bmm-shaped or (add)mm-shaped dequant pattern builder.

    A >2D, non-contiguous activation means linear was decomposed into
    expand + bmm; every other combination goes through mm/addmm.
    """
    decomposed_into_bmm = input_dim_exceeds_two and not input_contiguous
    if decomposed_into_bmm:
        return _generate_dequant_bmm_node_pattern(
            dequantize_per_channel_weight_pattern,
            dtype,
            with_bias,
            is_tensor_overload,
        )
    return _generate_dequant_linear_node_pattern(
        dequantize_per_channel_weight_pattern,
        dtype,
        input_dim_exceeds_two,
        is_tensor_overload,
    )
2353
+
2354
+
2355
def _register_dequant_promotion():
    """Register dequant-promotion patterns at pass_number=0.

    Four variants are registered per scale/zp overload, covering
    {fp32, bf16} x {rank==2, rank>2} activations. The matched subgraph is a
    dequantize_per_tensor, optionally followed by a to(bf16) cast and a
    reshape, whose output feeds multiple consumers; promotion duplicates it
    so each consumer owns its own dequant chain before weight prepack runs.
    """
    promotion_cases = itertools.product(
        [torch.float32, torch.bfloat16], [True, False], [True, False]
    )
    for dtype, input_dim_exceeds_two, is_tensor_overload in promotion_cases:
        # dequant -> OPT(to_bf16) portion of the pattern.
        dequant_chain = _may_generate_pattern_with_dtype_convert(
            get_dequantize_per_tensor_activation_pattern(
                is_tensor_overload=is_tensor_overload
            ),
            KeywordArg("autocast_act_dtype"),
            dtype == torch.bfloat16,
        )
        # pass_number=0 so promotion runs before the prepack passes (1/2).
        _register_dequant_promotion_pass(
            _may_generate_pattern_with_reshape(
                dequant_chain,
                KeywordArg("act_reshape_size"),
                with_reshape=input_dim_exceeds_two,
            ),
            pass_number=0,
            dtype=dtype,
        )
2395
+
2396
+
2397
def _register_qconv_weight_prepack():
    """Register the qconv weight-prepack passes for fp32 and bf16 flavors."""
    for conv_dtype in (torch.float32, torch.bfloat16):
        # pass_number=1 leaves pass 0 free for dequant promotion.
        for prepack_pattern in _generate_qconv_weight_prepack_patterns(conv_dtype):
            _register_qconv_weight_prepack_pass(
                prepack_pattern, pass_number=1, dtype=conv_dtype
            )
2405
+
2406
+
2407
def _register_qlinear_weight_prepack():
    """Register all qlinear weight-prepack pattern passes (mm/addmm and bmm)."""
    # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous.
    # Then convert the pattern into a QLinear node with int8_fp32/bf16.
    # Case 1: int8-mixed-fp32, input dim size is 2
    # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
    # Case 3: int8-mixed-bf16, input dim size is 2
    # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous

    #   + - - - - | - - - - - - | - - - - - +
    #   |    dq_per_tensor  dq_per_channel  |
    #   |         |              |          |
    #   |    OPT(to_bf16)    OPT(to_bf16)   |
    #   |         |              |          |
    #   |    OPT(reshape)     permute       |
    #   |           \        /              |
    #   |             addmm/mm              |
    #   |                |                  |
    #   |           OPT(reshape)            |

    # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
    # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous

    #   + - - - - | - - - - - - | - - - - - +
    #   |    dq_per_tensor  dq_per_channel  |
    #   |         |              |          |
    #   |    OPT(to_bf16)    OPT(to_bf16)   |
    #   |         |              |          |
    #   |       expand         permute      |
    #   |          \             |          |
    #   |                     expand        |
    #   |                    /              |
    #   |               bmm                 |
    #   |                |                  |
    #   |            OPT(add)               |

    linear_weight_prepack_cases = itertools.product(
        [torch.float32, torch.bfloat16], [True, False], [True, False]
    )

    # Step 1: register patterns from mm and addmm
    for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
        weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
            dtype,
            input_dim_exceeds_two,
            is_tensor_overload=is_tensor_overload,
        )
        # The generator returns the (bias, no-bias) pattern pair here.
        for weight_prepack_pattern in weight_prepack_patterns:
            # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
            _register_qlinear_weight_prepack_pass(
                weight_prepack_pattern,
                pass_number=1,
                dtype=dtype,
                input_dim_exceeds_two=input_dim_exceeds_two,
            )

    # Step 2: register patterns from bmm
    # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
    # refer to:
    # https://github.com/pytorch/pytorch/blob/
    # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
    # in this case, we can convert it back to qlinear
    for dtype, with_bias, is_tensor_overload in itertools.product(
        [torch.float32, torch.bfloat16], [True, False], [True, False]
    ):
        bmm_pattern = _generate_qlinear_weight_prepack_patterns(
            dtype=dtype,
            input_dim_exceeds_two=True,
            input_contiguous=False,
            with_bias=with_bias,
            is_tensor_overload=is_tensor_overload,
        )
        _register_qlinear_weight_prepack_pass(
            bmm_pattern,
            pass_number=1
            if with_bias
            else 2,  # if with_bias, there is an output add, so we should try to match it firstly
            dtype=dtype,
            input_dim_exceeds_two=True,
            input_contiguous=False,
        )
2487
+
2488
+
2489
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
    """Register every quantization weight-pack pattern pass exactly once.

    lru_cache(None) makes this idempotent: repeated calls are no-ops, so the
    patterns are never registered twice.
    """
    # Step 1: Dequant promotion for int8-mixed-fp32/bf16
    _register_dequant_promotion()

    # Step 2: QConv weight prepack
    _register_qconv_weight_prepack()

    # Step 3: QLinear weight prepack
    _register_qlinear_weight_prepack()
2499
+
2500
+
2501
def quant_lift_up(graph_module: torch.fx.GraphModule):
    """
    Lift up the quant node before view like nodes. It can benefit performance
    of Attention like block. For example, we have the pattern as:

             DQ
    DQ       LINEAR
    LINEAR   VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    Q        Q
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    We want to lift up the the quant nodes from matmul before view like nodes
    as the output of Linear node.

             DQ
    DQ       LINEAR
    LINEAR   Q
    Q        VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    It produces a DQ->LINEAR->Q pattern which can be fused by backend.
    """

    def is_view_op(node):
        # View-like nodes (from _VIEW_OPS) are the ones quant may hop over.
        return node.op == "call_function" and node.target in _VIEW_OPS

    for node in graph_module.graph.nodes:
        # <TODO> Leslie: Here we verify that the quant node has exactly
        # one input FX node, with constant scalar value for scale and zero point.
        # For the case input of quant node has more than one input FX nodes,
        # extend the implementation to lift up all the connected nodes
        # before the view nodes to keep the topological order.
        if (
            node.op == "call_function"
            and node.target in _PER_TENSOR_QUANTIZE_OPS
            and len(node.all_input_nodes) == 1
            and is_view_op(node.all_input_nodes[0])
        ):
            quant_node = node
            input_node_of_quant = quant_node.args[0]

            # Check the nodes along lift up path has only 1 user node
            # Propagate view like node to find where to insert the new quant node
            could_lift_up = True
            current_node = quant_node
            input_node = current_node.args[0]
            while is_view_op(input_node):
                if len(input_node.users) != 1:
                    could_lift_up = False
                    break
                current_node = input_node
                input_node = current_node.args[0]

            # Further check the input node of the first view node has only 1 user node
            if could_lift_up and len(input_node.users) == 1:
                # Replace dequant's input from quant to quant's input
                quant_node.replace_all_uses_with(input_node_of_quant)
                # Insert the new quant node
                with graph_module.graph.inserting_before(current_node):
                    new_quant_node = graph_module.graph.node_copy(quant_node)
                    input_node.replace_all_uses_with(new_quant_node)

                    # Update inputs of new_quant_node
                    def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
                        if n == input_node_of_quant:
                            return input_node
                        else:
                            return n

                    new_args = map_arg(new_quant_node.args, maybe_replace_node)
                    new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
                    new_quant_node.args = new_args  # type: ignore[assignment]
                    new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
                    graph_module.graph.erase_node(quant_node)

    graph_module.graph.lint()
    graph_module.recompile()
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import itertools
3
+ import logging
4
+ import operator
5
+ from collections import defaultdict
6
+ from dataclasses import dataclass
7
+ from typing import Any, Callable, Dict, List, Tuple
8
+
9
+ import torch
10
+ from torch._higher_order_ops.triton_kernel_wrap import (
11
+ kernel_side_table,
12
+ triton_kernel_wrapper_functional,
13
+ )
14
+ from torch._inductor import config, inductor_prims
15
+ from torch._inductor.fx_utils import get_node_storage, is_node_realized
16
+ from torch._inductor.lowering import (
17
+ inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
18
+ )
19
+ from torch._inductor.virtualized import V
20
+ from torch.fx.immutable_collections import immutable_dict
21
+ from torch.fx.passes.reinplace import _is_view_op
22
+ from torch.utils import _pytree as pytree
23
+
24
+
25
+ log = logging.getLogger(__name__)
26
+ aten = torch.ops.aten
27
+
28
+
29
@dataclass(frozen=True)
class InplaceableOp:
    """Describes how a functional op can be rewritten in-place.

    Attributes:
        inplace_op: the mutating variant to substitute for the functional op.
        mutated_arg: index of the argument that the in-place variant mutates.
        extra_check: predicate on the fx node gating the rewrite; defaults
            to always-true.
    """

    inplace_op: Callable[..., Any]
    mutated_arg: int
    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True
34
+
35
+
36
# Maps each view-scatter op to the view op it writes "through"; the inverse
# map (view -> scatter) is used when decomposing generalized scatters back
# into concrete view_scatter calls.
_SCATTER_OP_TO_VIEW = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}
43
+
44
+
45
def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
    """Insert a call_function node and eagerly populate its "val" meta.

    The fake value is computed by running `fn` under the active fake mode
    with every fx.Node argument substituted by its own `meta["val"]`.
    """

    def _fake_of(arg):
        return arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg

    fake_args, fake_kwargs = pytree.tree_map(_fake_of, (args, kwargs))
    with V.fake_mode:
        fake_result = fn(*fake_args, **fake_kwargs)

    new_node = graph.call_function(fn, args, kwargs)
    new_node.meta["val"] = fake_result
    return new_node
56
+
57
+
58
@dataclass
class ViewOp:
    """A recorded view call: the op overload plus its non-self args/kwargs."""

    target: torch._ops.OpOverload
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]
63
+
64
+
65
def _inplace_generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    """Copy `src` into the view of `inp` reached by applying `view_ops`.

    Mutates `inp` and returns it. Any fx.Node appearing in a view's
    args/kwargs is replaced by its `meta["val"]` before the call.
    """
    tmp = inp
    for view in view_ops:
        def _concrete(arg):
            return arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg

        concrete_args, concrete_kwargs = pytree.tree_map(
            _concrete, (view.args, view.kwargs)
        )
        tmp = view.target(tmp, *concrete_args, **concrete_kwargs)
    try:
        tmp.copy_(src)
    except RuntimeError as e:
        raise RuntimeError(
            f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}"
        ) from e
    return inp
82
+
83
+
84
def _generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    """Functional scatter: write `src` through `view_ops` into a clone of `inp`."""
    result = inp.clone()
    return _inplace_generalized_scatter(result, src, view_ops)
89
+
90
+
91
def _decompose_scatter_functional_helper(
    graph: torch.fx.Graph,
    inp: torch.Tensor,
    src: torch.Tensor,
    view_ops: List[ViewOp],
) -> torch.fx.Node:
    """Recursively lower one level of a generalized scatter to graph nodes.

    Materializes the view of `inp` for the head view op, recurses to compute
    the updated sub-tensor for the remaining ops, then emits the matching
    view_scatter op. Returns the node holding the fully-updated `inp`.
    """
    view_op, view_ops_tail = view_ops[0], view_ops[1:]

    if view_ops_tail:
        view = graph_call_function(
            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
        )
        # Use the already-computed tail instead of re-slicing view_ops.
        src = _decompose_scatter_functional_helper(graph, view, src, view_ops_tail)  # type: ignore[assignment]

    return graph_call_function(
        graph,
        _VIEW_OP_TO_SCATTER[view_op.target],
        inp,
        src,
        *view_op.args,
        **view_op.kwargs,
    )
113
+
114
+
115
def _decompose_scatter_functional(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter to a sequence of view_scatter operations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    view = aten.slice(inp, 0, 0, 10)
    view_updated = aten.slice_scatter(view, src, 1, 10, -10)
    inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
    """
    assert node.target is _generalized_scatter
    # node.args is (inp, src, view_ops); forward it directly to the helper
    # (the previous unpacking into locals was dead code).
    return _decompose_scatter_functional_helper(graph, *node.args)  # type: ignore[arg-type]
131
+
132
+
133
def _decompose_scatter_mutating(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter using mutations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    inp_updated = aten.clone(inp)
    slice1 = aten.slice(inp_updated, 0, 0, 10)
    slice2 = aten.slice(slice1, 1, 10, -10)
    slice2.copy_(src)

    """
    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
    inp, src, view_ops = node.args
    assert not node.kwargs

    # Functional variant: mutate a clone so the original input is preserved;
    # the inplace variant mutates `inp` directly.
    if node.target is _generalized_scatter:
        inp = graph_call_function(graph, aten.clone, inp)

    # Chain the views, then copy src into the innermost view.
    tmp = inp
    for view in view_ops:  # type: ignore[union-attr]
        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]

    graph_call_function(graph, aten.copy_.default, tmp, src)
    return inp  # type: ignore[return-value]
161
+
162
+
163
# View ops whose view_scatter op is lowered into mutations anyway,
# so is never a pessimisation to decompose.
# (Membership is tested against ViewOp.target in scatter_always_uses_mutation.)
_ALWAYS_MUTATING_SCATTER_OPS = {
    aten.as_strided.default,
    aten.diagonal.default,
}
169
+
170
+
171
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
    """True if any view in this scatter chain lowers to a mutation anyway."""
    view_ops = node.args[2]
    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]
174
+
175
+
176
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
    """Choose between mutating and functional scatter decompositions

    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
    input or output tensor computations. However, it is still profitable if the
    input and output would have been realized anyway.

    """
    inp = node.args[0]

    # Mutating scatter ops unconditionally realize input and output.
    if scatter_always_uses_mutation(node):
        return True

    # Both ends already realized: mutation costs nothing extra.
    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
        return True

    # If the output is copied back into the input, both get realized anyway,
    # since the copy_ makes the output a user of the input.
    output_copied_to_input = inp.op in ("placeholder", "get_attr") and any(  # type: ignore[union-attr]
        user.target is aten.copy_.default and user.args[0] is inp
        for user in node.users
    )
    if output_copied_to_input:
        return True

    # Otherwise, assume fusions will make functional variants profitable.
    return False
202
+
203
+
204
def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
    """Replace _generalized_scatter with normal aten ops"""
    scatter_nodes = itertools.chain(
        graph.find_nodes(op="call_function", target=_generalized_scatter),
        graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
    )
    for node in scatter_nodes:
        # Already-inplace scatters, and chains containing always-mutating
        # views, take the mutating decomposition; the rest stay functional.
        use_mutation = (
            node.target is _inplace_generalized_scatter
            or scatter_always_uses_mutation(node)
        )

        with graph.inserting_before(node):
            decompose = (
                _decompose_scatter_mutating
                if use_mutation
                else _decompose_scatter_functional
            )
            replacement = decompose(graph, node)

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)
223
+
224
+
225
def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
    """
    This canonicalizes view scatter ops into a generalized form, defined as:
      def scatter(inp, src, views):
        tmp = inp.clone()
        for view in views:
          tmp = view(tmp)
        tmp.copy_(src)

    We also fuse consecutive view scatter ops of the form
        a = scatter(view2(self), src, [view1])
        b = scatter(self, a, [view2])
    which can be rewritten as
        b = scatter(self, src, [view2, view1])
        a = view2(b)

    This is both more efficient as we only do a single scatter, and also
    easier to reinplace since there is only one use of `self`
    """

    # For each view node seen so far: the base tensor it views, and the chain
    # of ViewOps leading from that base to the node.
    node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
    node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)

    def handle_views(node: torch.fx.Node):
        # Record the node's base and extend its parent's view chain by one.
        inp = node.args[0]
        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
        node_to_view_op[node] = [
            *node_to_view_op[inp],  # type: ignore[index]
            ViewOp(
                node.target,  # type: ignore[arg-type]
                args=node.args[1:],
                kwargs=node.kwargs,
            ),
        ]

    def handle_view_scatter(node: torch.fx.Node):
        assert len(node.args) >= 2
        inp, src = node.args[:2]

        scatter_view_op = ViewOp(
            _SCATTER_OP_TO_VIEW[node.target],
            args=node.args[2:],
            kwargs=node.kwargs,
        )

        def can_fuse():
            # Fusable when src is itself a generalized scatter whose input is
            # exactly this scatter's view of the same base tensor.
            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
                return False
            src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]

            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
                *node_to_view_op[inp],  # type: ignore[index]
                scatter_view_op,
            ]

        if not can_fuse():
            # Not fusable: just wrap this single scatter in the generalized form.
            with graph.inserting_before(node):
                new_node = graph_call_function(
                    graph,
                    _generalized_scatter,
                    inp,
                    src,
                    [scatter_view_op],
                )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)
            return

        # Fused path: merge this scatter with the producing scatter `src`.
        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
        with graph.inserting_before(src):  # type: ignore[arg-type]
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
                inp,
                src_src,
                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
            )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if src.users:  # type: ignore[union-attr]
                # Other users of the inner scatter get the equivalent view of
                # the fused result: a = view(b).
                new_src = graph_call_function(
                    graph,
                    _SCATTER_OP_TO_VIEW[node.target],
                    new_node,
                    *node.args[2:],
                    **node.kwargs,
                )

                handle_views(new_src)
                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]

            graph.erase_node(src)  # type: ignore[arg-type]

    for node in graph.nodes:
        if _is_view_op(node.target):
            handle_views(node)
        elif node.target in _SCATTER_OP_TO_VIEW:
            handle_view_scatter(node)
326
+
327
+
328
# Registry of functional ops that can be rewritten to their mutating variants.
inplaceable_ops = {
    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
    _generalized_scatter: InplaceableOp(
        _inplace_generalized_scatter,
        0,
        extra_check=should_reinplace_scatter,
    ),
}

try:
    c10d_functional = torch.ops._c10d_functional
    inplaceable_collective_ops = {
        c10d_functional.all_reduce.default: InplaceableOp(
            c10d_functional.all_reduce_.default, 0
        ),
        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
            c10d_functional.all_reduce_coalesced_.default, 0
        ),
    }
    inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
    # _c10d_functional ops are only available when torch
    # is built with USE_DISTRIBUTED=1.
    pass

# Foreach ops reuse the inplace targets already declared by the lowerings.
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)


inplaceable_triton_ops = {triton_kernel_wrapper_functional}


# Operators that don't depend on the tensor data
META_ONLY_OPS = {
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.sym_numel.default,
    aten.sym_storage_offset.default,
}
369
+
370
+
371
def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
    """
    Reinplaces in-placeable operations.
    If there are no uses of a view of the mutated arg after the current node,
    it is possible to inplace the op.
    This above algorithm could be justified by observing side effects. While
    we traverse the graph in forwards direction, only latter nodes could view
    side effects of the current node. If the current node is not used later as
    well as no view of this node is used later in the graph, then it is safe to
    inplace as there would be no way to observe the side effects.
    This condition is slightly different for graph inputs where they can only
    be inplaced if the above condition is true and there's a copy_ in the
    epilogue that signals that the caller wants to observe the mutation.

    Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
    input args, so instead of checking mutation on placeholder, AOTInductor
    checks mutation on get_attr. This is subject to change in future.
    """

    # --- Phase 1: index the graph -------------------------------------------
    # copy_args_to_copy_nodes: (dst, src) -> the aten.copy_ epilogue node.
    copy_args_to_copy_nodes = {}
    # maps argument to the first copy_ node that mutates it.
    copy_nodes = {}
    mutated_inputs = set()
    # All nodes that share a storage, keyed by get_node_storage(node).
    storage_to_nodes = defaultdict(list)
    # Topological position of every node (0 = first); built by walking the
    # node list in reverse and subtracting the reversed index.
    node_order: Dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default and node.args[0].op in (
            "placeholder",
            "get_attr",
        ):
            dst = node.args[0]
            src = node.args[1]
            # If the target is a getitem and it indexes a possible clone,
            # then skip over it
            if src.target == operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node
            copy_nodes[dst] = node

            mutated_inputs.add(node.args[0])

    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
        # True iff some view sharing storage with `mutated_arg` is used after
        # `node` (and before the copy_ epilogue, if any) in a way that could
        # observe the mutation.
        node_loc = node_order[node]
        copy_node_loc = node_order[copy_node] if copy_node is not None else None

        def is_meta_only_user(node):
            # A user only reads metadata if it is a META_ONLY_OP, or a view op
            # all of whose (transitive) users are metadata-only.
            if _is_view_op(node.target):
                return all(is_meta_only_user(u) for u in node.users)
            return node.target in META_ONLY_OPS

        for view in shared_view_nodes:
            for user in view.users:
                user_loc = node_order[user]
                # Skip all users before node
                if user_loc <= node_loc:
                    continue
                # Ignore uses after the copy_ epilogue node, where the input
                # has already been mutated anyway
                if copy_node_loc is not None and copy_node_loc <= user_loc:
                    continue
                # Reinplacing does not change shape metadata
                if is_meta_only_user(user):
                    continue
                # If our graph looks like:
                # foo(mutated_arg)
                # mutated_arg.copy_(other)
                # then it's safe for us to reinplace foo because mutated_arg
                # will get overwritten anyways.
                if (
                    user.target is torch.ops.aten.copy_.default
                    and mutated_arg is user.args[0]
                ):
                    continue
                return True
        return False

    def can_inplace(node, mutated_arg):
        # Decide whether mutating `mutated_arg` in place at `node` is safe.
        # Handles Tensor lists (all elements must be distinct storages and
        # individually inplaceable), graph inputs (need a copy_ epilogue),
        # and intermediates (no later views may be used).
        if isinstance(mutated_arg, (list, tuple)):
            unique_storages = {get_node_storage(arg) for arg in mutated_arg}
            if len(unique_storages) != len(mutated_arg):
                # at least two Tensors in mutated_arg alias each other, so we can't reinplace it.
                # We can probably do better (that is, reinplace one of them and clone the other)
                # but that requires more work and mutable List[Tensor] are not that common.
                return False
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]

        if mutated_arg.op in ("placeholder", "get_attr"):
            # Get the first copy_ node that mutates the mutated_arg.
            copy_node = copy_nodes.get(mutated_arg, None)
            if copy_node is None:
                # There is no copy_ back to the candidate mutated_arg (which is a graph input).
                # Therefore the semantics of the program are that it does not mutate
                # mutated_arg, so we cannot re-inplace it.
                return False
            if any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
            ):
                return False

            return True
        elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
            # This should never happen in auto_functionalize_v2 non-inference mode,
            # since all mutated_arg are bases.

            # If mutated arg is view of any of the inputs of the graph,
            # do not allow for inplacing.
            # This would require more sophisticated algorithm to handle
            return False
        else:
            return not any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
            )

    def log_inplace_results(
        node_name,
        old_tensors_to_clone,
        tensors_to_clone,
        possibly_missed_reinplacing_opportunities,
    ):
        # Emit an info log and bump the dynamo counter so missed reinplacing
        # opportunities are visible in telemetry.
        log.info(
            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
            "memory usage and performance.",
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        torch._dynamo.utils.counters["inductor"][
            "possibly_missed_reinplacing_opportunities"
        ] += len(possibly_missed_reinplacing_opportunities)

    # Deferred node replacements (copy_ epilogues and getitem users), applied
    # in one batch after the main traversal so indices stay valid.
    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

    def reinplace_and_refine_tensors_to_clone(
        old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
    ):
        # For each candidate in `old_tensors_to_clone`, either reinplace it
        # (recording epilogue removals in replace_dict) or keep it in the
        # returned list of tensors that still must be cloned.
        tensors_to_clone: List[str] = []
        storage_of_reinplaced_args = set()
        possibly_missed_reinplacing_opportunities = []

        def tensor_with_same_storage_already_reinplaced(arg):
            if isinstance(arg, (list, tuple)):
                return any(
                    get_node_storage(a) in storage_of_reinplaced_args for a in arg
                )
            # NOTE(review): this reads `mutated_arg` from the enclosing loop
            # rather than `arg`; at the sole call site they are the same
            # object, so behavior is unchanged — but confirm before reusing.
            return get_node_storage(mutated_arg) in storage_of_reinplaced_args

        for arg in old_tensors_to_clone:
            assert arg in kwargs

            mutated_arg = kwargs[arg]

            # Let's say we have:
            # - op(x, y) that mutates both x and y
            # - new_x, new_y = functional_op(x, y) is the functional variant
            # If we are presented with functional_op(x, x), we must not reinplace
            # this into op(x, x), because then it would be writing to the same Tensor.
            # Instead, it's OK to reinplace one of them and to clone the other:
            # >>> y = x.clone()
            # >>> op(x, y)
            # This also applies if we have views: functional_op(x, x[0])
            # should not reinplace into op(x, x[0]).
            should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
                mutated_arg
            )
            if should_attempt_reinplace and can_inplace(node, mutated_arg):
                # In general, we probably do not need those optimizations.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                if not auto_functionalize_v2:
                    for user in node.users:
                        # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
                        # output at index size(out)+i.
                        # This used to compare string with integers before for auto_functionalize_v2. Not sure
                        # if it was needed for inplaceable_triton_ops?
                        if user.target == operator.getitem and user.args[1] == arg:
                            replace_dict[user] = mutated_arg

                if isinstance(mutated_arg, (list, tuple)):
                    for a in mutated_arg:
                        storage_of_reinplaced_args.add(get_node_storage(a))
                else:
                    storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
            else:
                if should_attempt_reinplace:
                    possibly_missed_reinplacing_opportunities.append(arg)
                tensors_to_clone.append(arg)

        log_inplace_results(
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        return tensors_to_clone

    # --- Phase 2: walk the graph and reinplace where safe -------------------
    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
            _mutable_op = node.args[0]
            kwargs = node.kwargs

            # v2 passes mutated bases positionally in "_all_bases"; candidates
            # are therefore integer indices rather than kwarg names.
            all_bases = kwargs["_all_bases"]
            bases_to_clone = range(len(all_bases))
            base_tensors_dct = dict(enumerate(all_bases))
            new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone(
                bases_to_clone,
                base_tensors_dct,
                node.target,
                auto_functionalize_v2=True,
            )
            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = new_bases_to_clone
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_args

            tensors_to_clone, _ = get_mutable_args(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone,
                node.kwargs,
                _mutable_op._name,
                auto_functionalize_v2=False,
            )

            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = tensors_to_clone
        elif node.target in inplaceable_triton_ops:
            kernel_idx = node.kwargs["kernel_idx"]
            kernel = kernel_side_table.get_kernel(kernel_idx)
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction

            # Recover a human-readable kernel name for logging.
            if isinstance(kernel, JITFunction):
                kernel_name = kernel.fn.__name__
            elif isinstance(kernel, Autotuner):
                if config.is_fbcode():
                    # Autotuner has different implementations for AMD and NV
                    if torch.version.hip is None:
                        kernel_name = kernel.base_fn.__name__
                    else:
                        kernel_name = kernel.fn.__name__
                else:
                    kernel_name = kernel.base_fn.__name__
            else:
                raise AssertionError("Unknown triton kernel type")

            # inplaceable_triton_ops take an additional argument called
            # tensors_to_clone which contain a list of tensors to clone
            # This pass iterates over them and sees which ones are safe
            # to eliminate (i.e. no longer need the clones)
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
            )

            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
            node.kwargs = immutable_dict(kwargs)
        elif (
            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
        ) is not None:
            mutated_args = node.args[inplaceable_op.mutated_arg]

            # Only reinplace a foreach op when every mutated list element has
            # a matching copy_ epilogue.
            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
                continue

            if can_inplace(node, mutated_args):
                for arg in mutated_args:
                    copy_node = copy_args_to_copy_nodes[(arg, node)]
                    replace_dict[copy_node] = copy_node.args[0]

                node.target = inplaceable_op.inplace_op
    # --- Phase 3: apply the deferred replacements ---------------------------
    for node, replacement in replace_dict.items():
        # Chase chains (a -> b -> c) to the final replacement, then memoize it.
        while replacement in replace_dict:
            replacement = replace_dict[replacement]
        replace_dict[node] = replacement

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)
683
+
684
+
685
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    """Run the full reinplacement pipeline on *graph*, in order:
    canonicalize view/scatter ops, reinplace what is safe, then decompose
    any remaining generalized scatters."""
    for graph_pass in (
        canonicalize_view_scatter_ops,
        reinplace_inplaceable_ops_core,
        decompose_generalized_scatter,
    ):
        graph_pass(graph)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import logging
4
+
5
+ import torch
6
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
7
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata
8
+
9
+ from .. import config, inductor_prims
10
+ from ..pattern_matcher import (
11
+ CallFunctionVarArgs,
12
+ Match,
13
+ PatternMatcherPass,
14
+ register_graph_pattern,
15
+ )
16
+ from ..virtualized import V
17
+
18
+
19
# Module-level logger for this pass.
log = logging.getLogger(__name__)
# Pattern-matcher pass that collects the registered random-op replacements below.
patterns = PatternMatcherPass()
# Shorthand for the aten op namespace.
aten = torch.ops.aten
22
+
23
+
24
def replace_random_passes(gm: torch.fx.GraphModule):
    """Modify the given FX graph to use backend-native random ops.

    Returns the number of graph changes made; returns 0 immediately when
    ``config.fallback_random`` requests eager-mode RNG behavior.
    """
    if config.fallback_random:
        return 0

    num_changes = patterns.apply(gm)
    observer = GraphTransformObserver(
        gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
    )
    with observer:
        num_changes += fuse_seed_creation_pass(gm.graph)
    return num_changes
36
+
37
+
38
def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the seed generation on each device

    a = inductor_seed(dev)
    b = inductor_seed(dev)

    Becomes:
    seeds = inductor_seeds(2, dev)
    a = inductor_lookup_seed(seeds, 0)
    b = inductor_lookup_seed(seeds, 1)

    We do this because seed creation is entirely launch overhead bound.

    Returns the number of devices whose seeds were fused (0 when the graph
    contains no inductor_prims.seed calls).
    """
    # Group every inductor_prims.seed(...) node by its device argument.
    device_seeds = collections.defaultdict(list)
    for node in graph.nodes:
        if CallFunctionVarArgs(inductor_prims.seed).match(node):
            device_seeds[node.args[0]].append(node)

    if not device_seeds:
        return 0

    for device, seeds in device_seeds.items():
        # Insert one combined seeds tensor before the earliest seed node so
        # it dominates every lookup that replaces an original seed.
        with graph.inserting_before(seeds[0]):
            combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
            with V.fake_mode:
                # Fake-tensor metadata: a 1-D int64 tensor with one entry
                # per fused seed.
                combined.meta["val"] = torch.empty(
                    [len(seeds)], device=device, dtype=torch.int64
                )
                combined.meta["tensor_meta"] = _extract_tensor_metadata(
                    combined.meta["val"]
                )

        # Replace each original seed node with a lookup into the combined
        # tensor, preserving the original node's metadata.
        for idx, seed in enumerate(seeds):
            with graph.inserting_before(seed):
                new_seed = graph.call_function(
                    inductor_prims.lookup_seed, (combined, idx)
                )
            seed.replace_all_uses_with(new_seed)
            new_seed.meta.update(seed.meta)
            graph.erase_node(seed)

    return len(device_seeds)
81
+
82
+
83
def default_kwargs(device):
    """Extra keyword arguments for the backend random ops.

    Currently there are none for any device; kept as an extension point.
    """
    return dict()
85
+
86
+
87
def get_device(device):
    """Return *device* as given, or the current default device when None."""
    if device is None:
        # Probe the default device by allocating an empty tensor.
        return torch.empty([]).device
    return device
91
+
92
+
93
@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite aten.rand/aten.randn calls into inductor_prims.random.

    Registered on the module-level ``patterns`` pass; ``layout`` and
    ``pin_memory`` are accepted to match the aten signatures but unused.
    """
    # An explicit generator means the user controls RNG state; leave the
    # node untouched so eager semantics are preserved.
    if generator is not None:
        return

    def replacement(size):
        # NOTE: `mode` and `device` are read from the enclosing scope; they
        # are assigned below, before replace_by_example traces this closure.
        result = inductor_prims.random(
            size, inductor_prims.seed(device), mode, **default_kwargs(device)
        )
        if dtype is not None:
            result = result.to(dtype)
        return result

    # Pick "rand" vs "randn" from the overload packet of the matched node.
    mode = {
        aten.rand: "rand",
        aten.randn: "randn",
    }[
        match.output_node().target.overloadpacket  # type: ignore[union-attr]
    ]  # type: ignore[union-attr]
    device = get_device(device)
    match.replace_by_example(replacement, [size])
126
+
127
+
128
@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite aten.randint.low into inductor_prims.randint.

    Registered on the module-level ``patterns`` pass; ``layout`` and
    ``pin_memory`` are accepted to match the aten signature but unused.
    """
    def replacement(low, high, size):
        # `device`/`dtype` are captured from the enclosing scope; the result
        # is cast to the requested dtype (default int64, matching aten).
        result = inductor_prims.randint(low, high, size, inductor_prims.seed(device))
        return result.to(dtype)

    device = get_device(device)
    match.replace_by_example(replacement, [low, high, size])
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (218 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc ADDED
Binary file (49.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc ADDED
Binary file (37.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc ADDED
Binary file (15.2 kB). View file