koichi12 commited on
Commit
c169562
·
verified ·
1 Parent(s): 445c885

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp +87 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +354 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py +78 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py +786 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py +1059 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py +341 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py +611 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py +121 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py +91 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py +395 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py +105 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py +85 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h +52 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h +57 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h +133 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h +50 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_copy_from_native.h +21 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h +113 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2r_cuda_dispatch.h +28 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_flash_attention_forward.h +47 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h +35 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_erfc_cuda_dispatch.h +24 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round.h +44 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_sin_native.h +25 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_tanh_cpu_dispatch.h +24 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_functional_assert_scalar_native.h +21 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_ops.h +39 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded.h +39 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h +28 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_prelu_kernel_backward_cuda_dispatch.h +23 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv.h +39 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h +26 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h +28 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/any_meta_dispatch.h +31 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h +23 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h +23 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/column_stack_native.h +22 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/concatenate.h +53 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_add_relu_ops.h +39 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h +23 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h +24 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc ADDED
Binary file (91.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc ADDED
Binary file (5.26 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // NOTE: Like interface.cpp, this file will be copied into AOTInductor
2
+ // generated output. This file is intended to keep implementation
3
+ // details separate from the implementation of the AOTI public
4
+ // interface. Note also that #includes should go into interface.cpp
5
+ // for simplicity of maintenance.
6
+
7
+ namespace torch {
8
+ namespace aot_inductor {
9
+ template <typename T>
10
+ void convert_output_to_handle(
11
+ const ArrayRefTensor<T>& output,
12
+ AtenTensorHandle& handle) {
13
+ handle = output.expensiveCopyToTensor();
14
+ }
15
+
16
+ template <typename... Ts, std::size_t... Is>
17
+ void convert_outputs_to_handles_helper(
18
+ const std::tuple<ArrayRefTensor<Ts>...>& outputs,
19
+ AtenTensorHandle* output_handles,
20
+ std::index_sequence<Is...>) {
21
+ (convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
22
+ }
23
+ template <typename... Ts>
24
+ void convert_outputs_to_handles(
25
+ const std::tuple<ArrayRefTensor<Ts>...>& outputs,
26
+ AtenTensorHandle* output_handles) {
27
+ convert_outputs_to_handles_helper(
28
+ outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
29
+ }
30
+
31
+ template <typename T>
32
+ void convert_handle_to_arrayref_tensor(
33
+ AtenTensorHandle handle,
34
+ ArrayRefTensor<T>& input) {
35
+ void* data_ptr;
36
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
37
+ int64_t dim;
38
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
39
+ int64_t numel;
40
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
41
+ int64_t* sizes;
42
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
43
+ int64_t* strides;
44
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
45
+ int32_t dtype;
46
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
47
+ int32_t device_type;
48
+ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
49
+ int32_t device_index;
50
+ AOTI_TORCH_ERROR_CODE_CHECK(
51
+ aoti_torch_get_device_index(handle, &device_index));
52
+
53
+ input = ArrayRefTensor<T>(
54
+ MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
55
+ MiniArrayRef<const int64_t>(sizes, dim),
56
+ MiniArrayRef<const int64_t>(strides, dim),
57
+ device_type,
58
+ device_index);
59
+ }
60
+
61
+ template <typename... Ts, std::size_t... Is>
62
+ void convert_handles_to_inputs_helper(
63
+ AtenTensorHandle* input_handles,
64
+ std::tuple<ArrayRefTensor<Ts>...>& inputs,
65
+ std::index_sequence<Is...>) {
66
+ (convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
67
+ ...);
68
+ }
69
+
70
+ template <typename... Ts>
71
+ void convert_handles_to_inputs(
72
+ AtenTensorHandle* input_handles,
73
+ std::tuple<ArrayRefTensor<Ts>...>& inputs) {
74
+ convert_handles_to_inputs_helper(
75
+ input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
76
+ }
77
+
78
+ template <typename T>
79
+ void assert_numel(const ArrayRefTensor<T>& tensor, int64_t numel) {
80
+ if (tensor.numel() != numel) {
81
+ std::stringstream err;
82
+ err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
83
+ throw std::runtime_error(err.str());
84
+ }
85
+ }
86
+ } // namespace aot_inductor
87
+ } // namespace torch
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
2
+ #include <torch/csrc/inductor/aoti_runtime/interface.h>
3
+ #include <torch/csrc/inductor/aoti_runtime/model_container.h>
4
+ #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
5
+ #include <torch/csrc/inductor/aoti_runtime/thread_local.h>
6
+
7
+ #include <iostream>
8
+ #include <sstream>
9
+ #include <stdexcept>
10
+ #include <vector>
11
+
12
+ #define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
13
+ try { \
14
+ __VA_ARGS__ \
15
+ } catch (const std::exception& e) { \
16
+ std::cerr << "Error: " << e.what() << std::endl; \
17
+ return AOTI_RUNTIME_FAILURE; \
18
+ } catch (...) { \
19
+ std::cerr << "Unknown exception occurred." << std::endl; \
20
+ return AOTI_RUNTIME_FAILURE; \
21
+ } \
22
+ return AOTI_RUNTIME_SUCCESS;
23
+
24
+ #define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
25
+ do { \
26
+ AOTI_RUNTIME_CHECK( \
27
+ actual_size == expected_size, \
28
+ "expected " + std::string(name) + " vector size to be " + \
29
+ std::to_string(expected_size) + ", but got " + \
30
+ std::to_string(actual_size)); \
31
+ } while (0)
32
+
33
+ // AOTInductor uses at::addmm_out, which doesn't supports
34
+ // arguments that requires gradient. For this reason, we
35
+ // enforce no_grad context for run APIs.
36
+ //
37
+ // A RAII, thread local (!) guard that enables or disables grad mode upon
38
+ // construction, and sets it back to the original value upon destruction.
39
+ struct AOTINoGradGuard {
40
+ AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
41
+ aoti_torch_grad_mode_set_enabled(false);
42
+ }
43
+ ~AOTINoGradGuard() {
44
+ aoti_torch_grad_mode_set_enabled(prev_mode);
45
+ }
46
+ bool prev_mode;
47
+ };
48
+
49
+ extern "C" {
50
+
51
+ AOTIRuntimeError AOTInductorModelContainerCreate(
52
+ AOTInductorModelContainerHandle* container_handle,
53
+ size_t num_models,
54
+ bool is_cpu,
55
+ const char* cubin_dir) {
56
+ return AOTInductorModelContainerCreateWithDevice(
57
+ container_handle,
58
+ num_models,
59
+ is_cpu ? "cpu" : "cuda",
60
+ cubin_dir);
61
+ }
62
+
63
+ AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
64
+ AOTInductorModelContainerHandle* container_handle,
65
+ size_t num_models,
66
+ const char* device_str,
67
+ const char* cubin_dir) {
68
+ if (num_models == 0) {
69
+ std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
70
+ return AOTI_RUNTIME_FAILURE;
71
+ }
72
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
73
+ std::optional<std::string> cubin_dir_opt;
74
+ if (cubin_dir != nullptr) {
75
+ cubin_dir_opt.emplace(cubin_dir);
76
+ }
77
+ auto* container = new torch::aot_inductor::AOTInductorModelContainer(
78
+ num_models, std::string(device_str), cubin_dir_opt);
79
+ *container_handle =
80
+ reinterpret_cast<AOTInductorModelContainerHandle>(container);
81
+ })
82
+ }
83
+
84
+ AOTIRuntimeError AOTInductorModelContainerDelete(
85
+ AOTInductorModelContainerHandle container_handle) {
86
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
87
+ auto* container =
88
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
89
+ container_handle);
90
+ delete container;
91
+ });
92
+ }
93
+
94
+ AOTIRuntimeError AOTInductorModelContainerRun(
95
+ AOTInductorModelContainerHandle container_handle,
96
+ AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
97
+ // are stolen; the array itself is borrowed
98
+ size_t num_inputs,
99
+ AtenTensorHandle*
100
+ output_handles, // array for writing output AtenTensorHandle; handles
101
+ // will be stolen by the caller; the array itself is
102
+ // borrowed
103
+ size_t num_outputs,
104
+ AOTInductorStreamHandle stream_handle,
105
+ AOTIProxyExecutorHandle proxy_executor_handle) {
106
+ auto* container =
107
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
108
+ container_handle);
109
+ AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
110
+ AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
111
+
112
+ auto stream =
113
+ reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
114
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
115
+ AOTINoGradGuard guard;
116
+ container->run(
117
+ input_handles, output_handles, stream, proxy_executor_handle);
118
+ })
119
+ }
120
+
121
+ AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
122
+ AOTInductorModelContainerHandle container_handle,
123
+ size_t* num_constants) {
124
+ auto* container =
125
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
126
+ container_handle);
127
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
128
+ { *num_constants = container->num_constants(); })
129
+ }
130
+
131
+ AOTIRuntimeError AOTInductorModelContainerGetConstantName(
132
+ AOTInductorModelContainerHandle container_handle,
133
+ size_t idx,
134
+ const char** name) {
135
+ auto* container =
136
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
137
+ container_handle);
138
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
139
+ { *name = container->constant_name(idx); })
140
+ }
141
+
142
+ AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
143
+ AOTInductorModelContainerHandle container_handle,
144
+ size_t idx,
145
+ const char** original_fqn) {
146
+ auto* container =
147
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
148
+ container_handle);
149
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
150
+ { *original_fqn = container->constant_original_fqn(idx); })
151
+ }
152
+
153
+ AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
154
+ AOTInductorModelContainerHandle container_handle,
155
+ size_t idx,
156
+ bool* from_folded) {
157
+ auto* container =
158
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
159
+ CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
160
+ }
161
+
162
+ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
163
+ AOTInductorModelContainerHandle container_handle,
164
+ size_t idx,
165
+ int32_t* dtype) {
166
+ auto* container =
167
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
168
+ container_handle);
169
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
170
+ { *dtype = container->constant_dtype(idx); })
171
+ }
172
+
173
+ AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
174
+ AOTInductorModelContainerHandle container_handle,
175
+ AOTInductorConstantMapHandle constant_map_handle,
176
+ bool use_inactive,
177
+ bool validate_full_update) {
178
+ auto* container =
179
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
180
+ container_handle);
181
+ auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
182
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
183
+ container->update_constant_buffer(
184
+ *input_map, use_inactive, validate_full_update);
185
+ })
186
+ }
187
+
188
+ AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
189
+ AOTInductorModelContainerHandle container_handle,
190
+ AOTInductorConstantMapHandle constant_map_handle) {
191
+ return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
192
+ constant_map_handle,
193
+ /*use_inactive*/ true,
194
+ /*validate_full_update*/ true);
195
+ }
196
+
197
+ AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
198
+ AOTInductorModelContainerHandle container_handle,
199
+ bool use_inactive,
200
+ AOTInductorStreamHandle stream_handle,
201
+ AOTIProxyExecutorHandle proxy_executor_handle) {
202
+ auto* container =
203
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
204
+ container_handle);
205
+ auto stream =
206
+ reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
207
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
208
+ AOTINoGradGuard guard;
209
+ container->run_const_fold(use_inactive, stream, proxy_executor_handle);
210
+ })
211
+ }
212
+
213
+ AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
214
+ AOTInductorModelContainerHandle container_handle) {
215
+ auto* container =
216
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
217
+ container_handle);
218
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
219
+ container->swap_constant_buffer();
220
+ })
221
+ }
222
+
223
+ AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
224
+ AOTInductorModelContainerHandle container_handle,
225
+ size_t* ret_num_inputs) {
226
+ auto* container =
227
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
228
+ container_handle);
229
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
230
+ { *ret_num_inputs = container->num_inputs(); })
231
+ }
232
+
233
+ AOTIRuntimeError AOTInductorModelContainerGetInputName(
234
+ AOTInductorModelContainerHandle container_handle,
235
+ size_t input_idx,
236
+ const char** ret_input_names) {
237
+ auto* container =
238
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
239
+ container_handle);
240
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
241
+ { *ret_input_names = container->input_name(input_idx); })
242
+ }
243
+
244
+ AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
245
+ AOTInductorModelContainerHandle container_handle,
246
+ size_t* ret_num_outputs) {
247
+ auto* container =
248
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
249
+ container_handle);
250
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
251
+ { *ret_num_outputs = container->num_outputs(); })
252
+ }
253
+
254
+ AOTIRuntimeError AOTInductorModelContainerGetOutputName(
255
+ AOTInductorModelContainerHandle container_handle,
256
+ size_t output_idx,
257
+ const char** ret_output_names) {
258
+ auto* container =
259
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
260
+ container_handle);
261
+ CONVERT_EXCEPTION_TO_ERROR_CODE(
262
+ { *ret_output_names = container->output_name(output_idx); })
263
+ }
264
+
265
+ AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
266
+ AOTInductorModelContainerHandle container_handle,
267
+ const char** in_spec,
268
+ const char** out_spec) {
269
+ auto* container =
270
+ reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
271
+ container_handle);
272
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
273
+ *in_spec = container->get_in_spec();
274
+ *out_spec = container->get_out_spec();
275
+ })
276
+ }
277
+
278
+ AOTIRuntimeError AOTInductorModelCreate(
279
+ AOTInductorModelHandle* model_handle,
280
+ AOTInductorConstantMapHandle constant_map_handle){
281
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
282
+ auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
283
+ auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
284
+ auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
285
+
286
+ auto model = new torch::aot_inductor::AOTInductorModel(
287
+ constant_map,
288
+ constant_array,
289
+ "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
290
+ ""
291
+ );
292
+
293
+ if (input_map) {
294
+ for (auto const& kv : *input_map) {
295
+ constant_map->emplace(kv.first, kv.second);
296
+ }
297
+ } else {
298
+ model->load_constants();
299
+ }
300
+
301
+ *model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
302
+ })}
303
+
304
+ AOTIRuntimeError AOTInductorModelRun(
305
+ AOTInductorModelHandle model_handle,
306
+ AtenTensorHandle* input_handles,
307
+ AtenTensorHandle* output_handles) {
308
+ auto model =
309
+ reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
310
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
311
+ AOTINoGradGuard guard;
312
+ model->run_impl(
313
+ input_handles,
314
+ output_handles,
315
+ (torch::aot_inductor::DeviceStreamType) nullptr,
316
+ nullptr);
317
+ })
318
+ }
319
+
320
+ AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
321
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
322
+ auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
323
+ model_handle);
324
+ delete model;
325
+ })}
326
+
327
+ AOTIRuntimeError AOTInductorModelGetNumOutputs(
328
+ AOTInductorModelHandle model_handle,
329
+ size_t* ret_num_outputs) {
330
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
331
+ auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
332
+ *ret_num_outputs = model->num_outputs();
333
+ })
334
+ }
335
+
336
+ AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
337
+ AOTInductorModelHandle model_handle,
338
+ AOTInductorConstantMapHandle constant_map_handle) {
339
+ auto model =
340
+ reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
341
+ CONVERT_EXCEPTION_TO_ERROR_CODE({
342
+ auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
343
+ auto input_map =
344
+ reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
345
+ constant_map_handle);
346
+
347
+ for (auto const& kv : *input_map) {
348
+ constant_map->emplace(kv.first, kv.second);
349
+ }
350
+ model->update_constants_map(std::move(constant_map));
351
+ })
352
+ }
353
+
354
+ } // extern "C"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Union
3
+
4
+ import torch
5
+ from torch.fx.experimental.proxy_tensor import py_sym_types, SymBool, SymFloat, SymInt
6
+
7
+
8
+ @dataclass
9
+ class _SymExprHash:
10
+ """
11
+ Hash for a py_sym_types that will use the underlying sympy expression
12
+ """
13
+
14
+ sym_obj: Union[SymInt, SymFloat, SymBool]
15
+
16
+ def __hash__(self) -> int:
17
+ return hash((type(self.sym_obj), self.sym_obj.node.expr))
18
+
19
+ def __eq__(self, value) -> bool:
20
+ if not isinstance(value, _SymExprHash):
21
+ return False
22
+ return self.sym_obj.node.expr == value.sym_obj.node.expr
23
+
24
+
25
+ class _SymHashingDict:
26
+ """
27
+ Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
28
+ existing sym proxies.
29
+
30
+ SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
31
+ fallback to symnodes.
32
+ """
33
+
34
+ def __init__(self):
35
+ self.sym_hash_dict = {}
36
+
37
+ def __setitem__(self, key, value):
38
+ self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
39
+
40
+ def __getitem__(self, key):
41
+ return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
42
+
43
+ def __contains__(self, key):
44
+ return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
45
+
46
+ def get(self, key, default=None):
47
+ return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
48
+
49
+ def _wrap_to_sym_expr_hash(self, key):
50
+ return _SymExprHash(key) if isinstance(key, py_sym_types) else key
51
+
52
+
53
+ def dedupe_symints(graph: torch.fx.Graph):
54
+ """
55
+ Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs.
56
+
57
+ We only dedupe from graph inputs to avoid adding a potential dependency in the forward
58
+ from the backward.
59
+
60
+ """
61
+
62
+ sym_dict = _SymHashingDict()
63
+ resolvable_from_input_symints = set()
64
+
65
+ for node in graph.nodes:
66
+ val = node.meta.get("val", None)
67
+ if val is None or not isinstance(val, py_sym_types):
68
+ continue
69
+
70
+ if node.op == "placeholder":
71
+ resolvable_from_input_symints.add(node)
72
+ sym_dict[val] = node
73
+ elif existing_node := sym_dict.get(val):
74
+ node.replace_all_uses_with(existing_node)
75
+ graph.erase_node(node)
76
+ elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
77
+ sym_dict[val] = node
78
+ resolvable_from_input_symints.add(node)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import inspect
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ from ..._dynamo.utils import counters
8
+ from ..pattern_matcher import (
9
+ filter_nodes,
10
+ fwd_only,
11
+ joint_fwd_bwd,
12
+ register_replacement,
13
+ )
14
+
15
+ log = logging.getLogger(__name__)
16
+ aten = torch.ops.aten
17
+
18
+
19
+ def _sfdp_pattern_1(query, key, value, inv_scale):
20
+ return (
21
+ torch.matmul(query, key.transpose(-2, -1))
22
+ .div(inv_scale)
23
+ .softmax(dim=-1)
24
+ .matmul(value)
25
+ )
26
+
27
+
28
+ def _sfdp_replacement_1(query, key, value, inv_scale):
29
+ counters["inductor"]["fuse_attention"] += 1
30
+ return aten.scaled_dot_product_attention(
31
+ query.contiguous(),
32
+ key.contiguous(),
33
+ value.contiguous(),
34
+ attn_mask=None,
35
+ dropout_p=0.0,
36
+ is_causal=False,
37
+ scale=1.0 / inv_scale,
38
+ )
39
+
40
+
41
+ def _sfdp_pattern_2(query, key, value, scale_factor):
42
+ return (
43
+ torch.matmul(query, key.transpose(-2, -1))
44
+ .mul(scale_factor)
45
+ .softmax(dim=-1)
46
+ .matmul(value)
47
+ )
48
+
49
+
50
+ def _sfdp_replacement_2(query, key, value, scale_factor):
51
+ counters["inductor"]["fuse_attention"] += 1
52
+ return aten.scaled_dot_product_attention(
53
+ query.contiguous(),
54
+ key.contiguous(),
55
+ value.contiguous(),
56
+ attn_mask=None,
57
+ dropout_p=0.0,
58
+ is_causal=False,
59
+ scale=scale_factor,
60
+ )
61
+
62
+
63
+ def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
64
+ return torch.nn.functional.dropout(
65
+ torch.matmul(query, key.transpose(-2, -1))
66
+ .div(inv_scale_factor)
67
+ .softmax(dim=-1),
68
+ p=dropout_p,
69
+ ).matmul(value)
70
+
71
+
72
+ def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
73
+ counters["inductor"]["fuse_attention"] += 1
74
+ return aten.scaled_dot_product_attention(
75
+ query.contiguous(),
76
+ key.contiguous(),
77
+ value.contiguous(),
78
+ attn_mask=None,
79
+ dropout_p=dropout_p,
80
+ is_causal=False,
81
+ scale=1.0 / inv_scale_factor,
82
+ )
83
+
84
+
85
+ def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
86
+ return torch.nn.functional.dropout(
87
+ torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
88
+ p=dropout_p,
89
+ ).matmul(value)
90
+
91
+
92
+ def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
93
+ counters["inductor"]["fuse_attention"] += 1
94
+ return aten.scaled_dot_product_attention(
95
+ query.contiguous(),
96
+ key.contiguous(),
97
+ value.contiguous(),
98
+ attn_mask=None,
99
+ dropout_p=dropout_p,
100
+ is_causal=False,
101
+ scale=scale_factor,
102
+ )
103
+
104
+
105
+ def _sfdp_pattern_5(query, key, value, attn_mask):
106
+ attn_weight = torch.softmax(
107
+ (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
108
+ )
109
+ # attn_weight = torch.dropout(attn_weight, dropout_p)
110
+ return attn_weight @ value
111
+
112
+
113
+ def _sfdp_replacement_5(query, key, value, attn_mask):
114
+ counters["inductor"]["fuse_attention"] += 1
115
+ return aten.scaled_dot_product_attention(
116
+ query.contiguous(),
117
+ key.contiguous(),
118
+ value.contiguous(),
119
+ attn_mask=attn_mask.to(dtype=query.dtype),
120
+ dropout_p=0.0,
121
+ is_causal=False,
122
+ )
123
+
124
+
125
+ def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
126
+ attn_weight = torch.softmax(
127
+ (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
128
+ )
129
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
130
+ return attn_weight @ value
131
+
132
+
133
+ def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
134
+ counters["inductor"]["fuse_attention"] += 1
135
+ return aten.scaled_dot_product_attention(
136
+ query.contiguous(),
137
+ key.contiguous(),
138
+ value.contiguous(),
139
+ attn_mask=attn_mask.to(dtype=query.dtype),
140
+ dropout_p=dropout_p,
141
+ is_causal=False,
142
+ )
143
+
144
+
145
+ def _sfdp_pattern_7(query, key, value, dropout_p):
146
+ # in real workloads inputs to matmul are permuted
147
+ # causing matmul to expand to a series of expand and clone calls
148
+ # we want the same to happen during pattern tracing
149
+ q = query.permute(0, 2, 1, 3)
150
+ k = key.permute(0, 2, 1, 3)
151
+ v = value.permute(0, 2, 1, 3)
152
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
153
+ div = div.to(torch.float32)
154
+ attn_weight = torch.softmax(div, dim=-1)
155
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
156
+ attn_weight = attn_weight.to(torch.float16)
157
+ return attn_weight @ v
158
+
159
+
160
+ def _sfdp_replacement_7(query, key, value, dropout_p):
161
+ # sdpa prefers inputs in permuted format
162
+ # it makes a copy to put them in this format
163
+ # if they aren't already
164
+ # to make replacement efficient ensure that inputs to sdpa
165
+ # are in required order
166
+ counters["inductor"]["fuse_attention"] += 1
167
+ q = query.permute(0, 2, 1, 3)
168
+ k = key.permute(0, 2, 1, 3)
169
+ v = value.permute(0, 2, 1, 3)
170
+ return aten.scaled_dot_product_attention(
171
+ q,
172
+ k,
173
+ v,
174
+ attn_mask=None, # attn_mask,
175
+ dropout_p=dropout_p,
176
+ is_causal=False,
177
+ )
178
+
179
+
180
+ def _sfdp_pattern_8(query, key, value):
181
+ # no dropout version of pattern 7
182
+ q = query.permute(0, 2, 1, 3)
183
+ k = key.permute(0, 2, 1, 3)
184
+ v = value.permute(0, 2, 1, 3)
185
+ div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
186
+ div = div.to(torch.float32)
187
+ attn_weight = torch.softmax(div, dim=-1)
188
+ attn_weight = attn_weight.to(torch.float16)
189
+ return attn_weight @ v
190
+
191
+
192
+ def _sfdp_replacement_8(query, key, value):
193
+ counters["inductor"]["fuse_attention"] += 1
194
+ q = query.permute(0, 2, 1, 3)
195
+ k = key.permute(0, 2, 1, 3)
196
+ v = value.permute(0, 2, 1, 3)
197
+ return aten.scaled_dot_product_attention(
198
+ q,
199
+ k,
200
+ v,
201
+ attn_mask=None, # attn_mask,
202
+ dropout_p=0.0,
203
+ is_causal=False,
204
+ )
205
+
206
+
207
+ def _sfdp_pattern_9(query, key, value, dropout_p):
208
+ q = query.permute(0, 2, 1, 3)
209
+ k = key.permute(0, 2, 1, 3)
210
+ v = value.permute(0, 2, 1, 3)
211
+ q = q / math.sqrt(q.size(-1))
212
+ div = q @ k.transpose(-2, -1)
213
+ div = div.to(torch.float32)
214
+ attn_weight = torch.softmax(div, dim=-1)
215
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
216
+ attn_weight = attn_weight.to(torch.float16)
217
+ return attn_weight @ v
218
+
219
+
220
+ def _sfdp_replacement_9(query, key, value, dropout_p):
221
+ counters["inductor"]["fuse_attention"] += 1
222
+ q = query.permute(0, 2, 1, 3)
223
+ k = key.permute(0, 2, 1, 3)
224
+ v = value.permute(0, 2, 1, 3)
225
+ return aten.scaled_dot_product_attention(
226
+ q,
227
+ k,
228
+ v,
229
+ attn_mask=None, # attn_mask,
230
+ dropout_p=dropout_p,
231
+ is_causal=False,
232
+ )
233
+
234
+
235
+ def _sfdp_pattern_10(query, key, value):
236
+ # no dropout version of 9
237
+ q = query.permute(0, 2, 1, 3)
238
+ k = key.permute(0, 2, 1, 3)
239
+ v = value.permute(0, 2, 1, 3)
240
+ q = q / math.sqrt(q.size(-1))
241
+ div = q @ k.transpose(-2, -1)
242
+ div = div.to(torch.float32)
243
+ attn_weight = torch.softmax(div, dim=-1)
244
+ attn_weight = attn_weight.to(torch.float16)
245
+ return attn_weight @ v
246
+
247
+
248
+ def _sfdp_replacement_10(query, key, value):
249
+ counters["inductor"]["fuse_attention"] += 1
250
+ q = query.permute(0, 2, 1, 3)
251
+ k = key.permute(0, 2, 1, 3)
252
+ v = value.permute(0, 2, 1, 3)
253
+ return aten.scaled_dot_product_attention(
254
+ q,
255
+ k,
256
+ v,
257
+ attn_mask=None, # attn_mask,
258
+ dropout_p=0.0,
259
+ is_causal=False,
260
+ )
261
+
262
+
263
+ def _sfdp_pattern_11(query, key, value, inv_scale):
264
+ # Mainly for huggingface models
265
+ q = query.permute(0, 2, 1, 3)
266
+ k = key.permute(0, 2, 1, 3)
267
+ v = value.permute(0, 2, 1, 3)
268
+ return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
269
+
270
+
271
+ def _sfdp_replacement_11(query, key, value, inv_scale):
272
+ counters["inductor"]["fuse_attention"] += 1
273
+ return aten.scaled_dot_product_attention(
274
+ query.transpose(1, 2),
275
+ key.transpose(1, 2),
276
+ value.transpose(1, 2),
277
+ attn_mask=None,
278
+ dropout_p=0.0,
279
+ is_causal=False,
280
+ scale=1.0 / inv_scale,
281
+ )
282
+
283
+
284
+ def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
285
+ q = query.permute(0, 2, 1, 3)
286
+ k = key.permute(0, 2, 1, 3)
287
+ v = value.permute(0, 2, 1, 3)
288
+ return torch.nn.functional.dropout(
289
+ torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
290
+ p=dropout_p,
291
+ ).matmul(v)
292
+
293
+
294
+ def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
295
+ counters["inductor"]["fuse_attention"] += 1
296
+ return aten.scaled_dot_product_attention(
297
+ query.transpose(1, 2),
298
+ key.transpose(1, 2),
299
+ value.transpose(1, 2),
300
+ attn_mask=None,
301
+ dropout_p=dropout_p,
302
+ is_causal=False,
303
+ scale=1.0 / inv_scale_factor,
304
+ )
305
+
306
+
307
+ def _sfdp_pattern_13(query, key, value, dropout_p):
308
+ attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
309
+ attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
310
+ return torch.bmm(attn_weight, value)
311
+
312
+
313
+ def _sfdp_replacement_13(query, key, value, dropout_p):
314
+ counters["inductor"]["fuse_attention"] += 1
315
+ return aten.scaled_dot_product_attention(
316
+ query.unsqueeze(0),
317
+ key.unsqueeze(0),
318
+ value.unsqueeze(0),
319
+ dropout_p=dropout_p,
320
+ scale=1.0,
321
+ ).squeeze(0)
322
+
323
+
324
+ def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
325
+ # for BertLarge
326
+ # Permutations are needed to create clones in graph.
327
+ q = query.permute([0, 2, 1, 3])
328
+ k = key.permute([0, 2, 1, 3])
329
+ v = value.permute([0, 2, 1, 3])
330
+ return (
331
+ (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
332
+ .softmax(dim=-1)
333
+ .matmul(v)
334
+ )
335
+
336
+
337
+ def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
338
+ counters["inductor"]["fuse_attention"] += 1
339
+ return aten.scaled_dot_product_attention(
340
+ query.transpose(1, 2),
341
+ key.transpose(1, 2),
342
+ value.transpose(1, 2),
343
+ attn_mask=attn_mask.to(dtype=query.dtype),
344
+ dropout_p=0.0,
345
+ is_causal=False,
346
+ scale=1.0 / inv_scale,
347
+ )
348
+
349
+
350
+ def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
351
+ # for DistilBert
352
+ # Permutations are needed to create clones in graph.
353
+ q = query.permute([0, 2, 1, 3])
354
+ k = key.permute([0, 2, 1, 3])
355
+ v = value.permute([0, 2, 1, 3])
356
+ bs = q.size(0)
357
+ k_len = k.size(-2)
358
+ scores = q @ k.transpose(-2, -1)
359
+ scores = scores.div(inv_scale)
360
+ fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
361
+ attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
362
+ return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
363
+
364
+
365
+ def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
366
+ counters["inductor"]["fuse_attention"] += 1
367
+ bs = query.size(0)
368
+ n_head = query.size(2)
369
+ q_len = query.size(1)
370
+ k_len = key.size(1)
371
+ # do attn_mask->logical_not() in aten.scaled_dot_product_attention
372
+ attn_mask = (
373
+ (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
374
+ )
375
+ return aten.scaled_dot_product_attention(
376
+ query.transpose(1, 2),
377
+ key.transpose(1, 2),
378
+ value.transpose(1, 2),
379
+ attn_mask=attn_mask.to(dtype=torch.bool),
380
+ dropout_p=0.0,
381
+ is_causal=False,
382
+ scale=1.0 / inv_scale,
383
+ )
384
+
385
+
386
+ def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
387
+ # for BertLarge with dropout
388
+ q = query.permute([0, 2, 1, 3])
389
+ k = key.permute([0, 2, 1, 3])
390
+ v = value.permute([0, 2, 1, 3])
391
+ return (
392
+ torch.nn.functional.dropout(
393
+ (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
394
+ dim=-1
395
+ ),
396
+ dropout_p,
397
+ )
398
+ .to(dtype=query.dtype)
399
+ .matmul(v)
400
+ )
401
+
402
+
403
+ def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
404
+ counters["inductor"]["fuse_attention"] += 1
405
+ return aten.scaled_dot_product_attention(
406
+ query.transpose(1, 2),
407
+ key.transpose(1, 2),
408
+ value.transpose(1, 2),
409
+ attn_mask=attn_mask.to(dtype=query.dtype),
410
+ dropout_p=dropout_p,
411
+ is_causal=False,
412
+ scale=1.0 / inv_scale,
413
+ )
414
+
415
+
416
+ def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
417
+ # for DistilBert with dropout
418
+ q = query.permute([0, 2, 1, 3])
419
+ k = key.permute([0, 2, 1, 3])
420
+ v = value.permute([0, 2, 1, 3])
421
+ bs = q.size(0)
422
+ k_len = k.size(-2)
423
+ scores = q @ k.transpose(-2, -1)
424
+ scores = scores.div(inv_scale)
425
+ fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
426
+ attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
427
+ return (
428
+ torch.nn.functional.dropout(
429
+ torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
430
+ )
431
+ @ v
432
+ )
433
+
434
+
435
+ def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
436
+ counters["inductor"]["fuse_attention"] += 1
437
+ bs = query.size(0)
438
+ n_head = query.size(2)
439
+ q_len = query.size(1)
440
+ k_len = key.size(1)
441
+ # do attn_mask->logical_not() in aten.scaled_dot_product_attention
442
+ attn_mask = (
443
+ (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
444
+ )
445
+ return aten.scaled_dot_product_attention(
446
+ query.transpose(1, 2),
447
+ key.transpose(1, 2),
448
+ value.transpose(1, 2),
449
+ attn_mask=attn_mask.to(dtype=torch.bool),
450
+ dropout_p=dropout_p,
451
+ is_causal=False,
452
+ scale=1.0 / inv_scale,
453
+ )
454
+
455
+
456
+ def _sfdp_params_check(match):
457
+ assert all(k in match.kwargs for k in ("query", "key", "value"))
458
+ query = match.kwargs["query"].meta["val"]
459
+ key = match.kwargs["key"].meta["val"]
460
+ value = match.kwargs["value"].meta["val"]
461
+ if not (query.dtype == key.dtype == value.dtype) or not (
462
+ query.device == key.device == value.device
463
+ ):
464
+ return False
465
+ add_mask_node = filter_nodes(match.nodes, aten.add.Tensor)
466
+ # Has attn_mask add.
467
+ if len(add_mask_node) > 0:
468
+ attn_mask_node = add_mask_node[0].args[1]
469
+ # attn_mask_node may be a float/int number.
470
+ if not hasattr(attn_mask_node, "meta"):
471
+ return False
472
+ attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr]
473
+ # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
474
+ # attn_mask.dtype == torch.float for models like albert.
475
+ if (
476
+ not isinstance(attn_mask, torch.Tensor)
477
+ or not (
478
+ attn_mask.dtype == query.dtype
479
+ or attn_mask.dtype == torch.bool
480
+ or attn_mask.dtype == torch.float
481
+ )
482
+ or query.device != attn_mask.device
483
+ ):
484
+ return False
485
+ return True
486
+
487
+
488
+ def _sfdp_extra_check(scale_factor_op, disable_cuda=False):
489
+ def fn(match):
490
+ scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
491
+ # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
492
+ scale_factor = scale_factor_node.args[1]
493
+ # make sure the scale_factor a float/int. SymInt?
494
+ if not isinstance(scale_factor, (float, int)):
495
+ return False
496
+ if (
497
+ disable_cuda
498
+ and "query" in match.kwargs
499
+ and "cuda" in str(match.kwargs["query"].meta["val"].device)
500
+ ):
501
+ return False
502
+ return _sfdp_params_check(match)
503
+
504
+ return fn
505
+
506
+
507
+ def partialize_and_update_signature(func, **kwargs):
508
+ """
509
+ Equivalent to functools.partial but also updates the signature on returned function
510
+ """
511
+ original_sig = inspect.signature(func)
512
+ parameters = original_sig.parameters
513
+
514
+ new_parameters = {
515
+ key: value for key, value in parameters.items() if key not in kwargs
516
+ }
517
+ new_sig = inspect.Signature(parameters=list(new_parameters.values()))
518
+
519
+ partial_func = functools.partial(func, **kwargs)
520
+
521
+ def wrapper(*args, **kwargs):
522
+ return partial_func(*args, **kwargs)
523
+
524
+ wrapper.__signature__ = new_sig # type: ignore[attr-defined]
525
+ wrapper.__name__ = func.__name__
526
+
527
+ return wrapper
528
+
529
+
530
+ def _get_sfdp_patterns():
531
+ from .joint_graph import patterns
532
+
533
+ if torch.cuda.is_available():
534
+ # workaround https://github.com/pytorch/pytorch/issues/97894
535
+ device = "cuda"
536
+ else:
537
+ device = "cpu"
538
+
539
+ # sizes/values don't actually matter for initial trace
540
+ # once we get a possible match we re-trace with the actual values and verify the match still holds
541
+ g_inp = functools.partial(
542
+ torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
543
+ )
544
+ # attn_mask
545
+ b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
546
+ m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
547
+ # inv_scale
548
+ c_inp = functools.partial(torch.tensor, 2.0, device=device)
549
+ # workaround https://github.com/pytorch/pytorch/issues/97894
550
+ # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
551
+ d = {"dropout_p": 0.113377}
552
+
553
+ # we could also generate all these patterns in 3d.. TODO
554
+ g_3d_inp = functools.partial(
555
+ torch.empty, (1024, 128, 128), device=device, requires_grad=True
556
+ )
557
+
558
+ # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
559
+ # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
560
+ # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
561
+ g_bs1_inp = functools.partial(
562
+ torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
563
+ )
564
+ m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)
565
+
566
+ # softmax will generate a dtype conversion on inputs if they are in half,
567
+ # but will not in float, so we generate a pattern for both
568
+ for dtype in [torch.float, torch.half]:
569
+ g = functools.partial(g_inp, dtype=dtype)
570
+ b = functools.partial(b_inp, dtype=dtype)
571
+ m = functools.partial(m_inp, dtype=dtype)
572
+ m_float = functools.partial(m_inp, dtype=torch.float)
573
+ c = functools.partial(c_inp, dtype=dtype)
574
+ g_3d = functools.partial(g_3d_inp, dtype=dtype)
575
+ g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
576
+ m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
577
+ m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
578
+
579
+ candidates = [
580
+ (
581
+ _sfdp_pattern_1,
582
+ _sfdp_replacement_1,
583
+ [g(), g(), g(), c()],
584
+ {},
585
+ _sfdp_extra_check(aten.div.Tensor),
586
+ ),
587
+ (
588
+ _sfdp_pattern_2,
589
+ _sfdp_replacement_2,
590
+ [g(), g(), g(), c()],
591
+ {},
592
+ _sfdp_extra_check(aten.mul.Tensor),
593
+ ),
594
+ (
595
+ _sfdp_pattern_3,
596
+ _sfdp_replacement_3,
597
+ [g(), g(), g(), c()],
598
+ d,
599
+ _sfdp_extra_check(aten.div.Tensor),
600
+ ),
601
+ (
602
+ _sfdp_pattern_4,
603
+ _sfdp_replacement_4,
604
+ [g(), g(), g(), c()],
605
+ d,
606
+ _sfdp_extra_check(aten.mul.Tensor),
607
+ ),
608
+ (
609
+ _sfdp_pattern_5,
610
+ _sfdp_replacement_5,
611
+ [g(), g(), g(), b()],
612
+ {},
613
+ _sfdp_params_check,
614
+ ),
615
+ (
616
+ _sfdp_pattern_6,
617
+ _sfdp_replacement_6,
618
+ [g(), g(), g(), b()],
619
+ d,
620
+ _sfdp_params_check,
621
+ ),
622
+ (
623
+ _sfdp_pattern_7,
624
+ _sfdp_replacement_7,
625
+ [g(), g(), g()],
626
+ d,
627
+ _sfdp_params_check,
628
+ ),
629
+ (
630
+ _sfdp_pattern_8,
631
+ _sfdp_replacement_8,
632
+ [g(), g(), g()],
633
+ {},
634
+ _sfdp_params_check,
635
+ ),
636
+ (
637
+ _sfdp_pattern_9,
638
+ _sfdp_replacement_9,
639
+ [g(), g(), g()],
640
+ d,
641
+ _sfdp_params_check,
642
+ ),
643
+ (
644
+ _sfdp_pattern_10,
645
+ _sfdp_replacement_10,
646
+ [g(), g(), g()],
647
+ {},
648
+ _sfdp_params_check,
649
+ ),
650
+ (
651
+ _sfdp_pattern_11,
652
+ _sfdp_replacement_11,
653
+ [g(), g(), g(), c()],
654
+ {},
655
+ _sfdp_extra_check(aten.div.Tensor),
656
+ ),
657
+ (
658
+ _sfdp_pattern_12,
659
+ _sfdp_replacement_12,
660
+ [g(), g(), g(), c()],
661
+ d,
662
+ _sfdp_extra_check(aten.div.Tensor),
663
+ ),
664
+ (
665
+ _sfdp_pattern_13,
666
+ _sfdp_replacement_13,
667
+ [g_3d(), g_3d(), g_3d()],
668
+ d,
669
+ _sfdp_params_check,
670
+ ),
671
+ (
672
+ _sfdp_pattern_14,
673
+ _sfdp_replacement_14,
674
+ [g(), g(), g(), m(), c()],
675
+ {},
676
+ _sfdp_extra_check(aten.div.Tensor),
677
+ ),
678
+ (
679
+ _sfdp_pattern_15,
680
+ _sfdp_replacement_15,
681
+ [g(), g(), g(), m(), c()],
682
+ {},
683
+ _sfdp_extra_check(aten.div.Tensor),
684
+ ),
685
+ # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
686
+ (
687
+ _sfdp_pattern_16,
688
+ _sfdp_replacement_16,
689
+ [g(), g(), g(), m(), c()],
690
+ d,
691
+ _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
692
+ ),
693
+ (
694
+ _sfdp_pattern_16,
695
+ _sfdp_replacement_16,
696
+ [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
697
+ d,
698
+ _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
699
+ ),
700
+ (
701
+ _sfdp_pattern_17,
702
+ _sfdp_replacement_17,
703
+ [g(), g(), g(), m(), c()],
704
+ d,
705
+ _sfdp_extra_check(aten.div.Tensor),
706
+ ),
707
+ ]
708
+ mask_fp32_patterns = ["pattern_16"]
709
+ if dtype == torch.half:
710
+ # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
711
+ candidates.append(
712
+ (
713
+ _sfdp_pattern_16,
714
+ _sfdp_replacement_16,
715
+ [g(), g(), g(), m_float(), c()],
716
+ d,
717
+ _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
718
+ )
719
+ )
720
+ candidates.append(
721
+ (
722
+ _sfdp_pattern_16,
723
+ _sfdp_replacement_16,
724
+ [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
725
+ d,
726
+ _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
727
+ )
728
+ )
729
+
730
+ for pattern, replacement, args, workaround, extra_check in candidates:
731
+ # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
732
+ # gets serialized to a python file and does not require tracing at runtime.
733
+ assert isinstance(workaround, dict)
734
+ name = pattern.__name__
735
+
736
+ if dtype != torch.float:
737
+ name += "_half"
738
+ if (
739
+ any(p in name for p in mask_fp32_patterns)
740
+ and args[3].dtype == torch.float32
741
+ ):
742
+ name += "_mask_fp32"
743
+ if args[0].size(0) == 1:
744
+ name += "_bs1"
745
+
746
+ training_name = name + "_training"
747
+ yield training_name, {
748
+ "search_fn": pattern,
749
+ "replace_fn": replacement,
750
+ "example_inputs": args,
751
+ "trace_fn": joint_fwd_bwd,
752
+ "pass_dicts": patterns,
753
+ "extra_check": extra_check,
754
+ "scalar_workaround": workaround,
755
+ }
756
+
757
+ if workaround:
758
+ assert len(workaround) == 1 and "dropout_p" in workaround
759
+ # functools.partial insufficient because we look at signature downstream
760
+ pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
761
+ replacement = partialize_and_update_signature(
762
+ replacement, dropout_p=0.0
763
+ )
764
+ workaround = {}
765
+
766
+ inference_name = name + "_inference"
767
+ yield inference_name, {
768
+ "search_fn": pattern,
769
+ "replace_fn": replacement,
770
+ "example_inputs": args,
771
+ "trace_fn": fwd_only,
772
+ "pass_dicts": patterns,
773
+ "extra_check": extra_check,
774
+ "scalar_workaround": workaround,
775
+ }
776
+
777
+
778
+ @functools.lru_cache(None)
779
+ def _sfdp_init():
780
+ from .serialized_patterns.central_index import get_serialized_pattern
781
+
782
+ for key, register_replacement_kwargs in _get_sfdp_patterns():
783
+ search_fn_pattern = get_serialized_pattern(key)
784
+ register_replacement(
785
+ **register_replacement_kwargs, search_fn_pattern=search_fn_pattern
786
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py ADDED
@@ -0,0 +1,1059 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import operator
4
+ from collections import OrderedDict
5
+ from typing import (
6
+ Any,
7
+ DefaultDict,
8
+ Deque,
9
+ Dict,
10
+ Iterable,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Set,
15
+ Tuple,
16
+ )
17
+
18
+ import torch
19
+ from torch._dynamo.utils import counters
20
+
21
+ from .. import config
22
+ from ..pattern_matcher import (
23
+ CallFunctionVarArgs,
24
+ get_arg_value,
25
+ stable_topological_sort,
26
+ )
27
+
28
+ try:
29
+ # importing this will register fbgemm lowerings for inductor
30
+ import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401
31
+
32
+ has_fbgemm = True
33
+ except Exception:
34
+ has_fbgemm = False
35
+ pass
36
+
37
+ aten = torch.ops.aten
38
+
39
+ log = logging.getLogger(__name__)
40
+
41
+ MIN_FUSE_SET_SIZE = 5
42
+ MAX_FUSE_SET_SIZE = 300
43
+ MAX_FUSE_SEARCH_DEPTH = 5
44
+ # The maximum tensor size that can go into the fusion group
45
+ MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096
46
+
47
+ # exclude these nodes from BFS
48
+ # excluding get item improves optimizer compilation time by 60s
49
+ SEARCH_EXCLUSIONS = {operator.getitem}
50
+
51
+
52
+ default_graph_search_options = {
53
+ "min_fuse_set_size": MIN_FUSE_SET_SIZE,
54
+ "max_fuse_set_size": MAX_FUSE_SET_SIZE,
55
+ "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
56
+ "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
57
+ }
58
+
59
+ graph_search_options = default_graph_search_options
60
+
61
+
62
+ def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
63
+ """
64
+ Update the example value of the node in the graph to enable followup split cat opt.
65
+ """
66
+ if node is not None and hasattr(node, "meta"):
67
+ if op == torch.stack:
68
+ example_value = torch.stack(metadata, dim=dim)
69
+ elif op == torch.unbind:
70
+ example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment]
71
+ else:
72
+ return
73
+ node.meta["example_value"] = example_value
74
+
75
+
76
+ def update_pointwise_example_value(pointwise_node, input, other, op):
77
+ """
78
+ Update the example value of the add node in the graph to enable followup split cat opt.
79
+ """
80
+ if pointwise_node is not None and hasattr(pointwise_node, "meta"):
81
+ if op == torch.add:
82
+ example_value = torch.add(input, other)
83
+ elif op == torch.mul:
84
+ example_value = torch.mul(input, other)
85
+ else:
86
+ return
87
+ pointwise_node.meta["example_value"] = example_value
88
+
89
+
90
+ class GroupBatchFusionBase:
91
+ def __init__(self, **kwargs):
92
+ self.graph_search_options = kwargs.pop(
93
+ "graph_search_options", default_graph_search_options
94
+ )
95
+
96
+ def match(self, node):
97
+ raise NotImplementedError("match called on base")
98
+
99
+ def fuse(self, graph, subset):
100
+ raise NotImplementedError("fuse called on base")
101
+
102
+
103
+ PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict()
104
+ POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict()
105
+
106
+
107
+ def register_fusion(name: str, pre_grad=True):
108
+ def decorator(fusion_cls: GroupBatchFusionBase):
109
+ if pre_grad:
110
+ PRE_GRAD_FUSIONS[name] = fusion_cls
111
+ else:
112
+ POST_GRAD_FUSIONS[name] = fusion_cls
113
+ return fusion_cls
114
+
115
+ return decorator
116
+
117
+
118
+ def list_group_batch_fusions(pre_grad=True) -> List[str]:
119
+ if pre_grad:
120
+ return list(PRE_GRAD_FUSIONS.keys())
121
+ else:
122
+ return list(POST_GRAD_FUSIONS.keys())
123
+
124
+
125
+ def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
126
+ unsqueezed_inputs = []
127
+ for input_tensor in input_tensors:
128
+ unsqueezed_input = graph.call_function(
129
+ aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
130
+ )
131
+ unsqueezed_inputs.append(unsqueezed_input)
132
+ stacked_inputs = graph.call_function(
133
+ aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
134
+ )
135
+ return stacked_inputs
136
+
137
+
138
+ class GroupFusion(GroupBatchFusionBase):
139
+ """
140
+ Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
141
+ """
142
+
143
+ pass
144
+
145
+
146
+ class BatchFusion(GroupBatchFusionBase):
147
+ """
148
+ Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm.
149
+ """
150
+
151
+ pass
152
+
153
+
154
+ class BatchPointwiseOpsFusionFactory(BatchFusion):
155
+ def __init__(self, op, **kwargs):
156
+ super().__init__(**kwargs)
157
+ self.op = op
158
+
159
+
160
+ @register_fusion("batch_linear_post_grad", pre_grad=False)
161
+ class PostGradBatchLinearFusion(BatchFusion):
162
+ """
163
+ Fuse ops in a batch way in post grad (aten level).
164
+ """
165
+
166
+ def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
167
+ return (
168
+ node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
169
+ )
170
+
171
+ def _is_input_2d(self, input: torch.fx.Node) -> bool:
172
+ input_shapes = input.meta["tensor_meta"].shape
173
+ return (
174
+ len(input_shapes) == 2
175
+ and isinstance(input_shapes[0], int)
176
+ and isinstance(input_shapes[1], int)
177
+ )
178
+
179
+ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]:
180
+ if CallFunctionVarArgs(aten.mm).match(node):
181
+ input_m, weight_m = node.args
182
+ bias_m = None
183
+
184
+ elif CallFunctionVarArgs(aten.addmm.default).match(
185
+ node
186
+ ) and self._addmm_node_can_be_fused(node):
187
+ bias_m, input_m, weight_m = node.args
188
+ else:
189
+ return None
190
+
191
+ # only handle the cases where inputs are 2D tensors
192
+ if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type]
193
+ return None
194
+ m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr]
195
+ n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr]
196
+ batch_key = ("batch_linear", m, k, n, bias_m is not None)
197
+ return batch_key
198
+
199
+ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
200
+ batch_inputs = []
201
+ batch_weights = []
202
+ batch_biases = []
203
+ batch_nodes = []
204
+
205
+ for node in subset:
206
+ if CallFunctionVarArgs(aten.addmm.default).match(node):
207
+ bias, input, weight = node.args
208
+ elif CallFunctionVarArgs(aten.mm.default).match(node):
209
+ input, weight = node.args
210
+ bias = None
211
+ batch_nodes.append(node)
212
+ batch_inputs.append(input) # type: ignore[possibly-undefined]
213
+ batch_weights.append(weight) # type: ignore[possibly-undefined]
214
+ batch_biases.append(bias) # type: ignore[possibly-undefined]
215
+
216
+ with graph.inserting_before(subset[-1]):
217
+ fused_inputs = decompose_stack(graph, batch_inputs)
218
+ fused_weights = decompose_stack(graph, batch_weights)
219
+ fused_bmm = graph.call_function(
220
+ aten.bmm,
221
+ args=(fused_inputs, fused_weights),
222
+ )
223
+
224
+ for i, original_mm in enumerate(batch_nodes):
225
+ has_bias = False
226
+ with graph.inserting_after(fused_bmm):
227
+ new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))
228
+ if batch_biases[i]:
229
+ has_bias = True
230
+ new_bias_add = graph.call_function(
231
+ aten.add, args=((batch_biases[i], new_mm))
232
+ )
233
+ new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined]
234
+ original_mm.replace_all_uses_with(new_mm_cont)
235
+ new_mm_cont.meta.update(original_mm.meta)
236
+ graph.erase_node(original_mm)
237
+
238
+
239
@register_fusion("group_linear", pre_grad=False)
class GroupLinearFusion(GroupFusion):
    """Group independent mm/addmm nodes into a single fbgemm grouped-matmul
    (``torch.ops.fbgemm.gmm``) call in the post-grad graph."""

    def _addmm_node_can_be_fused(self, node: torch.fx.Node):
        # Only plain addmm (beta == alpha == 1) over 2D, even-sized operands
        # that stay within the configured size limit is eligible.
        input_shape = node.args[1].meta["tensor_meta"].shape  # type: ignore[union-attr]
        weight_shape = node.args[2].meta["tensor_meta"].shape  # type: ignore[union-attr]
        return (
            node.kwargs.get("beta", 1.0) == 1.0
            and node.kwargs.get("alpha", 1.0) == 1.0
            and len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def _mm_node_can_be_fused(self, node: torch.fx.Node):
        # Same shape/size constraints as addmm, minus the beta/alpha check.
        input_shape = node.args[0].meta["tensor_meta"].shape  # type: ignore[union-attr]
        weight_shape = node.args[1].meta["tensor_meta"].shape  # type: ignore[union-attr]
        return (
            len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
        """Return a grouping key for fusable mm/addmm nodes, else None.

        The key's boolean distinguishes bias-free groups from biased ones so
        only like nodes are fused together.
        """
        if CallFunctionVarArgs(aten.mm.default).match(
            node
        ) and self._mm_node_can_be_fused(node):
            group_key = ("group_linear", True)
        elif CallFunctionVarArgs(aten.addmm.default).match(
            node
        ) and self._addmm_node_can_be_fused(node):
            bias = node.args[0]
            group_key = ("group_linear", bias is None)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Replace each node in ``subset`` with the matching output of one
        ``fbgemm.gmm`` call over the collected inputs/weights/biases."""
        group_inputs = []
        group_weights = []
        group_biases = []
        group_nodes = []
        for node in subset:
            if CallFunctionVarArgs(aten.addmm.default).match(node):
                bias, input, weight = node.args
            else:
                assert CallFunctionVarArgs(aten.mm.default).match(node)
                input, weight = node.args
                bias = None

            group_nodes.append(node)
            group_inputs.append(input)
            group_weights.append(weight)
            group_biases.append(bias)

        # gmm expects biases=None (not a list of Nones) when no node has one.
        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        group_biases: Optional[List[Any]]

        with graph.inserting_before(subset[0]):
            fused_mm = graph.call_function(
                torch.ops.fbgemm.gmm.default,
                args=(group_inputs, group_weights, group_biases),
                kwargs={"smart_fused": True},
            )

        for i, original_mm in enumerate(group_nodes):
            with graph.inserting_after(fused_mm):
                # gmm returns a list; pick the i-th output for this node.
                new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))
                original_mm.replace_all_uses_with(new_mm)
                new_mm.meta.update(original_mm.meta)
                graph.erase_node(original_mm)
318
+
319
+
320
class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise operator (e.g., add, mul) in post grad pass.

    Independent binary pointwise nodes with matching shapes/dtypes are stacked
    and replaced by one batched op, with results sliced back out via select.
    """

    def __init__(self, op, **kwargs):
        super().__init__(op, **kwargs)
        # The aten overload (e.g. aten.add.Tensor) this instance batches.
        self.op = op

    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
        # note: we only consider the case where the inputs are tensors
        # for mixed precision training, we need to make sure the inputs
        # of the aten.cat when do the stack should be the same dtype
        # otherwise, the output of the aten.cat may be not the same as
        # its inputs, and cause dtype not same error in mm or addmm
        input, other = node.args
        return (
            input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape  # type: ignore[union-attr]
            if hasattr(input, "meta")
            and hasattr(other, "meta")
            and "tensor_meta" in input.meta  # type: ignore[union-attr]
            and "tensor_meta" in other.meta  # type: ignore[union-attr]
            else False
        )

    def match(self, node: torch.fx.Node):
        """Group by op, operand shape/dtypes, alpha and rounding_mode so only
        exactly-compatible pointwise calls end up in the same batch."""
        if CallFunctionVarArgs(self.op).match(
            node
        ) and self._pointwise_node_can_be_fused(node):
            alpha = node.kwargs.get("alpha", 1.0)
            rounding_mode = node.kwargs.get("rounding_mode", None)
            input, other = node.args
            shape = list(input.meta["tensor_meta"].shape)  # type: ignore[union-attr]
            group_key = (
                "batch_" + self.op.__name__.lower() + "_post_grad",
                str(shape),
                str(input.meta["tensor_meta"].dtype),  # type: ignore[union-attr]
                str(other.meta["tensor_meta"].dtype),  # type: ignore[union-attr]
                str(alpha),
                str(rounding_mode),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the operands of all nodes in ``subset``, apply one batched op
        and replace each original node with a select of the result."""
        batch_inputs, batch_others = [], []
        # match() grouped by alpha, so the first node's alpha applies to all.
        alpha = subset[0].kwargs.get("alpha", 1.0)

        for node in subset:
            input, other = node.args
            batch_inputs.append(input)
            batch_others.append(other)

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            stack_others = decompose_stack(graph, batch_others)

            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs, stack_others),
                # Only aten.add.Tensor accepts alpha.
                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
            )
        for i, original_add in enumerate(subset):
            with graph.inserting_after(batch_op):
                new_add = graph.call_function(
                    torch.ops.aten.select, args=((batch_op, 0, i))
                )
            original_add.replace_all_uses_with(new_add)
            new_add.meta.update(original_add.meta)
            graph.erase_node(original_add)
391
+
392
+
393
@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
    """
    Batch linear left-hand side fusion. This pass tries to fuse the following patterns:

    torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn)
    -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1))

    We have a separate pass to eliminate contiguous transpose in a generic way.
    """

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
        """Group linear calls that share the same input tensor (the LHS);
        the key includes the input node so only same-x calls fuse."""
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            bias = get_arg_value(node, 2, "bias")
            group_key = ("batch_linear_lhs", bias is None, input)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Concatenate all weights (and biases) along dim 0, run one
        mm/addmm against the shared input, then split the result back into
        per-node outputs along dim 1."""
        batch_nodes = []
        batch_input = None
        batch_weights = []
        batch_biases = []
        # Output-feature count per linear; used to split the fused result.
        split_sections = []
        for node in subset:
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            batch_nodes.append(node)
            if batch_input is None:
                batch_input = input
            else:
                # match() keyed on the input node, so all must be identical.
                assert batch_input is input
            batch_weights.append(weight)
            if bias:
                batch_biases.append(bias)
            split_sections.append(weight.meta["example_value"].shape[0])

        with graph.inserting_before(subset[0]):
            cat_weights = graph.call_function(
                torch.cat, args=(batch_weights,), kwargs={"dim": 0}
            )
            transposed_weights = graph.call_function(
                torch.transpose, args=(cat_weights, 0, 1)
            )
            if len(batch_biases) > 0:
                cat_biases = graph.call_function(
                    torch.cat, args=(batch_biases,), kwargs={"dim": 0}
                )
                fused_lhs = graph.call_function(
                    torch.addmm,
                    args=(cat_biases, batch_input, transposed_weights),
                )
            else:
                fused_lhs = graph.call_function(
                    torch.mm,
                    args=(batch_input, transposed_weights),
                )
            fused_lhs_list = graph.call_function(
                torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
            )

        for i, node in enumerate(batch_nodes):
            with graph.inserting_after(fused_lhs_list):
                new_node = graph.call_function(
                    operator.getitem, args=(fused_lhs_list, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
467
+
468
+
469
def is_node_meta_valid(node: Optional[torch.fx.Node]):
    """Return True when ``node`` is absent or carries an "example_value" meta.

    A missing node (None) is treated as trivially valid; an actual node is
    valid only if its ``meta`` dict records an example value.
    """
    return node is None or "example_value" in node.meta
475
+
476
+
477
def is_linear_node_can_be_fused(node: torch.fx.Node):
    """Check that a linear call and its input/weight all have example-value
    metadata and that both input and weight are 2D tensors."""
    input = get_arg_value(node, 0, "input")
    weight = get_arg_value(node, 1, "weight")
    # All three nodes must carry "example_value" meta before shapes can be read.
    if not (
        is_node_meta_valid(node)
        and is_node_meta_valid(input)
        and is_node_meta_valid(weight)
    ):
        return False
    input_rank = len(input.meta["example_value"].shape)
    weight_rank = len(weight.meta["example_value"].shape)
    return input_rank == 2 and weight_rank == 2
487
+
488
+
489
@register_fusion("batch_linear")
class PreGradBatchLinearFusion(BatchFusion):
    """
    Batch linear fusion in pre grad pass.
    Fuse linear with same size with torch.baddmm
    """

    def _getitem_args(self, getitem_node: torch.fx.Node):
        # Returns the container a getitem reads from, or None for non-getitem
        # nodes; used so inputs split from the same source group together.
        if getitem_node.target != operator.__getitem__ or (
            getitem_node.op != "call_function"
        ):
            return None
        return getitem_node.args[0]

    def match(self, node: torch.fx.Node):
        """Group linear calls with identical input/weight shapes (and the same
        getitem source, if any) so they can be stacked into one baddbmm."""
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            group_key = (
                "batch_linear_pre_grad",
                self._getitem_args(input),
                str(input.meta["example_value"].shape),
                str(weight.meta["example_value"].shape),
                bias is None,
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack inputs/weights (and biases) into batched tensors, run one
        bmm/baddbmm, then unbind the result back into per-node outputs."""
        batch_nodes = []
        batch_inputs = []
        batch_weights = []
        batch_biases = []
        # example_value metadata collected so the stacked nodes get meta too.
        batch_inputs_metadata = []
        batch_weights_metadata = []
        batch_biases_metadata = []
        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])
            weight = get_arg_value(node, 1, "weight")
            batch_weights.append(weight)
            batch_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 2, "bias")
            batch_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                batch_biases_metadata.append(bias.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            stack_weights = graph.call_function(
                torch.stack, args=(batch_weights,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_weights, batch_weights_metadata)
            # linear multiplies by weight^T, so transpose the stacked weights.
            transpose_weight = graph.call_function(
                torch.transpose, args=(stack_weights, 1, 2)
            )
            if all(bias is None for bias in batch_biases):
                bmm = graph.call_function(
                    torch.bmm,
                    args=(stack_inputs, transpose_weight),
                )
            else:
                stack_biases = graph.call_function(
                    torch.stack, args=(batch_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_biases, batch_biases_metadata)
                # Insert a broadcast dim so biases add across each batch slice.
                unsqueeze_biases = graph.call_function(
                    torch.unsqueeze, args=(stack_biases, 1)
                )
                bmm = graph.call_function(
                    torch.baddbmm,
                    args=(unsqueeze_biases, stack_inputs, transpose_weight),
                )

            bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
            for i, linear in enumerate(batch_nodes):
                with graph.inserting_after(bmm):
                    getitem = graph.call_function(operator.getitem, args=(bmm, i))
                linear.replace_all_uses_with(getitem)
                getitem.meta.update(linear.meta)
                graph.erase_node(linear)
579
+
580
+
581
@register_fusion("batch_layernorm")
class BatchLayernormFusion(BatchFusion):
    """
    Batch layer norm fusion in pre grad pass
    """

    def match(self, node: torch.fx.Node):
        """Group layer_norm calls by input/weight/bias shapes, normalized
        shape and eps; only nodes with valid example-value meta are grouped."""
        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 2, "weight")
            bias = get_arg_value(node, 3, "bias")
            group_key = (
                (
                    "batch_layernorm",
                    str(input.meta["example_value"].shape),
                    str(weight.meta["example_value"].shape)
                    if weight is not None
                    else "",
                    str(bias.meta["example_value"].shape) if bias is not None else "",
                    str(get_arg_value(node, 1, "normalized_shape")),
                    str(get_arg_value(node, 4, "eps")),
                )
                if "example_value" in input.meta
                and is_node_meta_valid(weight)
                and is_node_meta_valid(bias)
                else None
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the inputs of all grouped layer_norm calls, normalize once
        (without affine), apply the stacked weight/bias as explicit mul/add,
        then unbind the result back into per-node outputs."""
        group_inputs = []
        group_shapes = []
        group_weights = []
        group_biases = []
        group_epss = []
        group_nodes = []
        group_inputs_metadata = []
        group_biases_metadata = []
        group_weights_metadata = []
        for node in subset:
            group_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            group_inputs.append(input)
            group_inputs_metadata.append(input.meta["example_value"])
            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
            weight = get_arg_value(node, 2, "weight")
            group_weights.append(weight)
            if weight is not None and hasattr(weight, "meta"):
                group_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 3, "bias")
            group_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                group_biases_metadata.append(bias.meta["example_value"])
            eps = get_arg_value(node, 4, "eps")
            if eps is None:
                # layer_norm's documented default eps.
                eps = 1e-5
            group_epss.append(eps)
        # Stack just before the normalized dims so the norm still reduces over
        # the original normalized_shape.
        stack_dim = -1 - len(group_shapes[-1])

        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        group_biases: Optional[List[Any]]
        if all(weight is None for weight in group_weights):
            group_weights = None  # type: ignore[assignment]
        group_weights: Optional[List[Any]]
        assert all(
            eps == group_epss[0] for eps in group_epss
        ), "all epsilon values must be equal"

        with graph.inserting_before(subset[0]):
            stack_input = graph.call_function(
                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
            )
            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
            if group_weights is not None:
                stack_weight = graph.call_function(
                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_weight, group_weights_metadata)
            else:
                stack_weight = None
            if group_biases is not None:
                stack_bias = graph.call_function(
                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_bias, group_biases_metadata)
            else:
                stack_bias = None

            # Normalize without weight/bias; the affine transform is applied
            # explicitly below so each slice gets its own weight/bias.
            batch_layer_norm = graph.call_function(
                torch.nn.functional.layer_norm,
                args=(stack_input, group_shapes[-1]),
                kwargs={"eps": group_epss[-1]},
            )
            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]

            if group_weights is not None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )
            elif group_weights is not None and group_biases is None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
            elif group_weights is None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )

            batch_layer_norm_unbind = graph.call_function(
                torch.unbind,
                args=(batch_layer_norm,),
                kwargs={"dim": stack_dim},
            )
            update_stack_example_value(
                batch_layer_norm_unbind,
                batch_layer_norm.meta["example_value"],
                op=torch.unbind,
                dim=stack_dim,
            )

        for i, node in enumerate(group_nodes):
            with graph.inserting_after(batch_layer_norm_unbind):
                new_node = graph.call_function(
                    operator.getitem, args=(batch_layer_norm_unbind, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
743
+
744
+
745
class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch poinwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass.
    We fuse it in random place, and the introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs):
        super().__init__(op, **kwargs)
        # The torch-level callable (e.g. torch.tanh) this instance batches.
        self.op = op

    def match(self, node: torch.fx.Node):
        """Group unary pointwise calls by op, input shape and inplace flag."""
        input = get_arg_value(node, 0, "input")
        # NOTE(review): only the node's meta is validated here; `input.meta`
        # is read below without a check — presumably inputs always carry
        # "example_value" at this point, but confirm against upstream callers.
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            # for relu op, we also use the inplace to construct the key
            group_key = (
                "batch_" + self.op.__name__.lower() + "_pre_grad",
                str(input.meta["example_value"].shape),
                str(node.kwargs.get("inplace", False)),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack all inputs, apply the op once, then unbind the batched
        result and replace each original node with its slice."""
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            if self.op == torch.nn.functional.relu:
                # relu accepts inplace; propagate the first node's setting
                # (match() grouped by inplace, so they all agree).
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
                )
            else:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                )
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
        for i, node in enumerate(batch_nodes):
            with graph.inserting_after(unbind_op):
                getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
            node.replace_all_uses_with(getitem)
            getitem.meta.update(node.meta)
            graph.erase_node(node)
804
+
805
+
806
@register_fusion("batch_tanh")
class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batch independent torch.tanh calls in the pre-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(torch.tanh, **kwargs)
810
+
811
+
812
@register_fusion("batch_sigmoid")
class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batch independent torch.sigmoid calls in the pre-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(torch.sigmoid, **kwargs)
816
+
817
+
818
@register_fusion("batch_relu")
class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Batch independent torch.nn.functional.relu calls in the pre-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(torch.nn.functional.relu, **kwargs)
822
+
823
+
824
@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batch independent aten.add.Tensor calls in the post-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(aten.add.Tensor, **kwargs)
828
+
829
+
830
@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batch independent aten.sub.Tensor calls in the post-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(aten.sub.Tensor, **kwargs)
834
+
835
+
836
@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batch independent aten.div.Tensor calls in the post-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(aten.div.Tensor, **kwargs)
840
+
841
+
842
@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Batch independent aten.mul.Tensor calls in the post-grad graph."""

    def __init__(self, **kwargs):
        super().__init__(aten.mul.Tensor, **kwargs)
846
+
847
+
848
+ class _OrderedSet:
849
+ def __init__(self, param=None):
850
+ if param:
851
+ self.rep = OrderedDict({k: None for k in param})
852
+ else:
853
+ self.rep = OrderedDict()
854
+
855
+ def __contains__(self, o):
856
+ return o in self.rep
857
+
858
+ def __len__(self):
859
+ return self.rep.__len__()
860
+
861
+ def append(self, o):
862
+ self.rep[o] = None
863
+
864
+ def __iter__(self):
865
+ return self.rep.keys().__iter__()
866
+
867
+
868
def find_independent_subset_greedy(
    node_list: Iterable[torch.fx.Node],
    graph_search_options: Dict[str, Any],
) -> Iterator[Iterable[torch.fx.Node]]:
    """
    Yields a list of subsets of `node_list` where no element in the subset
    depends on any other element in the subset. This results in a set of
    independent nodes which can be fused together.

    The order of `node_list` is preserved within each subset so we can benefit
    from split-cat elimination in later passes.

    During iteration it is only safe to mutate the graph by changing the nodes
    that have been returned.

    graph_search_options:
        - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
          this size will be ignored.
        - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
          be broken to be at most this size.
    """

    # Compute all the children of `node` which are members of
    # `interesting_nodes`.
    def find_dependent_nodes(node, interesting_nodes):
        visited_node_set: Set[torch.fx.Node] = {node}
        dep_set: Set[torch.fx.Node] = set()

        # Iterative DFS over the node's transitive inputs.
        work = [node]
        while work:
            node = work.pop()
            for input_node in node.all_input_nodes:
                if input_node in interesting_nodes:
                    dep_set.add(input_node)

                if input_node not in visited_node_set:
                    visited_node_set.add(input_node)
                    work.append(input_node)

        return dep_set

    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
    max_fuse_set_size = graph_search_options["max_fuse_set_size"]

    # node_list needs to be a set because we only track the nodes that are left
    # in it (and we want to do the `in` on a set, not a list). But we want to
    # keep the correct order.
    node_list = _OrderedSet(node_list)

    # Memoizes each remaining node's dependency set across rounds.
    cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
    while node_list:
        subset: List[torch.fx.Node] = []
        subset_deps: Set[torch.fx.Node] = set()

        next_round_node_list = _OrderedSet()
        for node in node_list:
            # Defer nodes that would overflow the subset or that some already
            # chosen node depends on.
            if len(subset) >= max_fuse_set_size or node in subset_deps:
                next_round_node_list.append(node)
                continue

            dep_set = cache.pop(node, None)
            if dep_set is None:
                dep_set = find_dependent_nodes(node, node_list)

            if not dep_set.intersection(subset):
                subset.append(node)
                subset_deps.update(dep_set)
            else:
                next_round_node_list.append(node)
                cache[node] = dep_set

        if len(subset) >= min_fuse_set_size:
            # Careful here - the caller uses the subsets to fuse nodes together
            # so we need to clear any cache entry that contains one of the
            # returned nodes because the dependency list could be different
            # (larger) after the merge.
            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
            yield subset

        node_list = next_round_node_list
948
+
949
+
950
def get_fusion_candidates(
    rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
    """
    Search fusion candidates for a specific rule using BFS starting from the root node.
    We only search the subgraph within graph_search_options["max_fuse_search_depth"].

    Returns a mapping from the rule's match key to the (ordered, de-duplicated)
    list of candidate nodes sharing that key. Nodes already in `fused_set` are
    skipped, and the search does not descend past a matching node.
    """
    q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()

    candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
        list
    )

    # Certain targets are excluded from the search entirely.
    if root_node.target in SEARCH_EXCLUSIONS:
        return candidate_dict

    visited_set: Set[torch.fx.Node] = set()

    # Seed the BFS with the root's direct inputs at depth 1.
    for next_node in root_node.all_input_nodes:
        q.append((1, next_node))
        visited_set.add(next_node)

    while len(q) > 0:
        depth, node = q.popleft()

        if node in fused_set:
            continue

        key = rule.match(node)
        if key is not None:
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
        else:
            # Only keep exploring through non-matching nodes, up to the depth limit.
            if depth < rule.graph_search_options["max_fuse_search_depth"]:
                for next_node in node.all_input_nodes:
                    if next_node not in visited_set:
                        visited_set.add(next_node)
                        q.append((depth + 1, next_node))

    return candidate_dict
991
+
992
+
993
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
    """Apply one fusion rule to the whole graph.

    Walks nodes in reverse topological order, collects match-keyed candidate
    groups for each node, splits each group into independent subsets, and
    fuses every subset that meets the minimum size, updating counters and
    tracking fused nodes so they are not considered again.
    """
    stable_topological_sort(graph)  # type: ignore[arg-type]
    fused_set: Set[torch.fx.Node] = set()

    for node in reversed(graph.nodes):
        candidates = get_fusion_candidates(rule, node, fused_set)

        for key, candidate_nodes in candidates.items():
            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
                continue

            for subset in find_independent_subset_greedy(
                candidate_nodes, rule.graph_search_options
            ):
                rule.fuse(graph, subset)
                fused_set.update(subset)
                if isinstance(rule, GroupFusion):
                    counters["inductor"]["group_fusion"] += 1
                elif isinstance(rule, BatchFusion):
                    counters["inductor"]["batch_fusion"] += 1
                else:
                    counters["inductor"]["unknown_group_batch_fusion"] += 1

                log.debug(
                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}"  # noqa: G004
                )
1019
+
1020
+
1021
def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
    """Instantiate the fusion rules named in ``config_options``.

    Each entry's per-rule options are overlaid on the global
    ``graph_search_options`` defaults before the registered fusion class is
    constructed from the pre- or post-grad registry.
    """
    registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
    fusions: List[GroupBatchFusionBase] = []
    for name, options in config_options.items():
        merged_options = dict(graph_search_options)
        merged_options.update(options)
        fusions.append(registry[name](graph_search_options=merged_options))  # type: ignore[operator]
    return fusions
1029
+
1030
+
1031
def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    """Build the configured fusion rules and run each over the graph.

    Pre-grad: all configured pre-grad fusions run. Post-grad: fusions that
    require fbgemm are split out and only run when fbgemm is available.
    """
    fusions: List[GroupBatchFusionBase] = []
    # we keep all current pre grad fusions to keep
    # current implementation, will remove this later
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
    else:
        # Partition post-grad fusions by whether they need fbgemm kernels.
        fbgemm_fusion_keys = [
            x
            for x in config.post_grad_fusion_options
            if config.post_grad_fusion_options[x].get("require_fbgemm", False)
        ]
        fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in fbgemm_fusion_keys
        }
        non_fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in config.post_grad_fusion_options.keys()
            if fusion not in fbgemm_fusion_keys
        }
        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
        if has_fbgemm:
            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for rule in fusions:
        apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import typing
3
+ from collections import Counter
4
+ from typing import Dict, List, Set
5
+
6
+ import torch
7
+ import torch._guards
8
+ from torch._inductor.constant_folding import ConstantFolder
9
+ from torch.multiprocessing.reductions import StorageWeakRef
10
+
11
+ from .. import config
12
+ from ..pattern_matcher import (
13
+ CallFunction,
14
+ init_once_fakemode,
15
+ KeywordArg,
16
+ Match,
17
+ PatternMatcherPass,
18
+ register_graph_pattern,
19
+ stable_topological_sort,
20
+ )
21
+ from .replace_random import replace_random_passes
22
+
23
+ log = logging.getLogger(__name__)
24
+ patterns = PatternMatcherPass()
25
+
26
+
27
@init_once_fakemode
def lazy_init():
    """One-time registration of joint-graph pattern passes.

    Runs under fake-tensor mode (via @init_once_fakemode) and initializes the
    pad-mm, scaled-dot-product-attention, and misc pattern matchers.
    """
    from .fuse_attention import _sfdp_init
    from .misc_patterns import _misc_patterns_init
    from .pad_mm import _pad_mm_init

    _pad_mm_init()
    _sfdp_init()
    _misc_patterns_init()
36
+
37
+
38
@torch.utils._python_dispatch._disable_current_modes()
def remove_no_ops(
    gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
):
    "Removes no-ops: (+ 0, - 0, * 1, / 1)"
    # `zeros` / `ones` are the graph nodes known to be all-zero / all-one
    # constants (identified by the caller).
    aten = torch.ops.aten
    graph = gm.graph

    def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
        # Compare two fake tensors on the given fields; non-tensors never match.
        if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
            return False
        for field in fields:
            if getattr(t1, field) != getattr(t2, field):
                return False
        return True

    def replace_no_op(node, replace_input_index):
        # Replace `node` with its argument at `replace_input_index`, inserting
        # a dtype conversion when only shape/device (not dtype) match.
        replacement = node.args[replace_input_index]

        # https://github.com/pytorch/pytorch/issues/86128 causes
        # non-Tensor inputs even for ops with only Tensor inputs.
        # TODO - decompose/type promote to avoid this
        if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
            return

        if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
            if fake_tensors_eq(
                node.meta["val"],
                replacement.meta["val"],
                ("shape", "device"),
            ):
                with graph.inserting_after(node):
                    replacement = graph.call_function(
                        torch.ops.prims.convert_element_type.default,
                        args=(replacement, node.meta["val"].dtype),
                    )
            else:
                # Shape or device differ — not actually a no-op; keep node.
                return

        node.replace_all_uses_with(replacement)
        replacement.meta.update(node.meta)
        graph.erase_node(node)

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        # TODO handle Tensor-Scalar adds, it's a different schema
        if node.target == aten.add.Tensor and len(node.args) == 2:
            # x + 0 or 0 + x is a no-op only when alpha == 1.
            if (
                not any(e in zeros for e in node.args)
                or node.kwargs.get("alpha", 1) != 1
            ):
                continue

            replace_index = 1 if node.args[0] in zeros else 0
            replace_no_op(node, replace_index)

        elif node.target == aten.sub.Tensor and len(node.args) == 2:
            # Only x - 0 (not 0 - x) is a no-op, and only when alpha == 1.
            if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
                continue

            replace_no_op(node, 0)

        elif node.target == aten.mul.Tensor and len(node.args) == 2:
            # x * 1 or 1 * x.
            if not any(e in ones for e in node.args):
                continue

            replace_input_index = 1 if node.args[0] in ones else 0
            replace_no_op(node, replace_input_index)

        elif (
            node.target == aten.div.Tensor
            and len(node.args) == 2
            and node.args[1] in ones
        ):
            # x / 1.
            replace_no_op(node, 0)
115
+
116
+
117
@torch.utils._python_dispatch._disable_current_modes()
def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant views by reusing existing ones.
    """

    # A dictionary mapping a tensor to all aliased views.
    # Each value maps a dtype to the node that views the source as that dtype;
    # all aliases of one source share the same inner dict.
    views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        # Only dtype-reinterpreting views (aten.view.dtype) are deduplicated.
        if node.target != torch.ops.aten.view.dtype:
            continue

        src = node.args[0]
        to_type = node.args[1]
        existing_views = views.get(src)
        is_needed = True

        if existing_views:
            # Replace the view with the an existing view if available.
            alias = existing_views.get(to_type)
            if alias:
                is_needed = False
                node.replace_all_uses_with(alias)
                alias.meta.update(node.meta)
                graph.erase_node(node)
        else:
            # First view of this source: seed the alias dict with the source
            # itself keyed by its own dtype.
            from_type = src.meta["val"].dtype
            existing_views = {from_type: src}
            views[src] = existing_views

        if is_needed:
            # Save the new alias but do not replace existing one.
            existing_views.setdefault(to_type, node)
            views[node] = existing_views

    # Clean up unused views.
    # Erasing one view may make its input (another view) unused, so loop to a
    # fixed point.
    while True:
        unused_views = [alias for alias in views if not alias.users]
        if len(unused_views) == 0:
            break
        for unused in unused_views:
            views.pop(unused)
            graph.erase_node(unused)
165
+
166
+
167
class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
    with a tensor constructor call: aten.full([shape], value, ...)
    """

    def __init__(self, gm, skip_constructors=False):
        super().__init__(gm, skip_constructors)
        # Maps folded node -> weak ref of its storage, used by the caller to
        # skip tensors whose storage is aliased by another constant.
        self.node_storages_ptrs: Dict[torch.fx.Node, int] = {}
        self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {}
        # we may constant fold a tensor which in the graph has a sym size
        # see: [constant folding refining of symints]
        self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
        # TODO - we could also Tensors which get replaced with arange here
        # Only accept non-empty, strided tensors with a real storage whose
        # elements are all equal to the first element (i.e. uniform value).
        return (
            t.numel() != 0
            and bool((t == t.flatten()[0]).all())
            and torch._C._has_storage(t)
            and t.layout == torch.strided
        )

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        # Record only the scalar value plus shape/storage metadata rather than
        # the full tensor, so folding does not retain large buffers.
        self.node_replacements[node] = tensor.flatten()[0].item()
        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
        shape = list(tensor.shape)
        assert all(type(dim) is int for dim in shape)
        self.node_replacements_shapes[node] = shape
196
+
197
+
198
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold_uniform_value(gm: torch.fx.GraphModule):
    "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops."
    aten = torch.ops.aten

    # Constant folding can leak memory, especially with repeated compilation, so we are only going to
    # remove constants which can be replaced with a constructor.
    cf = UniformValueConstantFolder(gm)
    cf.run()

    node_replacements = cf.node_replacements

    # note: [constant folding refining of symints]
    # constant folding will partially evaluate a graph such that values which have dependencies which
    # are entirely known at compile time may also become compile time constants. in some cases,
    # this will include symints which we had not yet previously deduced are guaranteed a
    # constant value and is then deduced in constant folding. an example is:
    # unbacked_symint_eq_11 = torch.full((), 11).item()
    # torch.full((unbacked_symint_eq_11,), 0)
    node_replacements_shapes = cf.node_replacements_shapes

    graph = gm.graph

    # Collect the inserted full() nodes that are all-zero / all-one so
    # remove_no_ops can eliminate arithmetic with them afterwards.
    zeros = set()
    ones = set()

    # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
    # so just constant-ify if a Tensor is unaliased
    constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()

    for node in cf.node_replacements:
        constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1

    for node, value in node_replacements.items():
        # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now
        # hasn't shown up to be important yet
        fake_tensor = node.meta["val"]
        if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
            continue

        # Skip constants whose storage is shared with another constant.
        if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
            continue

        with graph.inserting_after(node):
            # the conversion from tensor and back to value can be lossy, just use the original full ctor value
            if (
                node.op == "call_function"
                and node.target == aten.full.default
                and len(node.args) == 2
            ):
                value = node.args[1]

            # refines symints, see [constant folding refining of symints] above
            for runtime_size, compile_time_size in zip(
                node_replacements_shapes[node], fake_tensor.shape
            ):
                torch._check(runtime_size == compile_time_size)

            # zeros, and ones just get traced into full, so we insert those
            new_node = graph.call_function(
                aten.full.default,
                args=(node_replacements_shapes[node], value),
                kwargs={
                    "dtype": fake_tensor.dtype,
                    "layout": torch.strided,
                    "device": fake_tensor.device,
                    "pin_memory": False,
                },
            )

            new_node.meta.update(node.meta)
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if value == 0:
                zeros.add(new_node)
            elif value == 1:
                ones.add(new_node)

    # Clean up arithmetic with the uniform constants, then drop any dtype
    # views made redundant by the rewrites above.
    remove_no_ops(gm, zeros, ones)
    remove_redundant_views(gm)
279
+
280
+
281
def joint_graph_passes(graph: torch.fx.GraphModule):
    """
    Run FX transformations on the joint forwards+backwards graph.

    Applies (gated by config): uniform-value constant folding, the
    module-level pattern set (`patterns`), and random-op replacement.
    Recompiles the module only when a pattern pass reported changes.
    """
    lazy_init()
    count = 0

    if config.joint_graph_constant_folding:
        constant_fold_uniform_value(graph)

    if config.pattern_matcher:
        count += patterns.apply(graph.graph)  # type: ignore[arg-type]

    if not config.fallback_random:
        count += replace_random_passes(graph)

    if count:
        # Pattern replacements can insert nodes out of order; restore a valid
        # topological order before linting/recompiling.
        stable_topological_sort(graph.graph)
        graph.graph.lint()
        graph.recompile()
    return graph
302
+
303
+
304
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        CallFunction(
            torch.ops.prims.convert_element_type.default,
            KeywordArg("arg"),
            KeywordArg("dtype1"),
        ),
        KeywordArg("dtype2"),
    ),
    pass_dict=patterns,
)
def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
    """Remove chain of dtype conversions often created by AMP"""
    # convert(convert(x, dtype1), dtype2) -> convert(x, dtype2), valid only
    # when both dtypes are floating point (the intermediate cast cannot
    # change the result beyond precision, which the fusion accepts).
    graph = match.graph
    node = match.output_node()
    allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64}
    if dtype1 in allowed and dtype2 in allowed:
        repl = graph.call_function(
            torch.ops.prims.convert_element_type.default, (arg, dtype2)
        )
        repl.meta.update(node.meta)
        node.replace_all_uses_with(repl)
        match.erase_nodes(graph)
328
+
329
+
330
@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Remove no-op view"""
    # A view whose target size equals the input's size is an identity;
    # forward all uses to the input node.
    graph = match.graph
    node = match.output_node()
    arg_size = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
    if size == arg_size:
        node.replace_all_uses_with(node.args[0])
        match.erase_nodes(graph)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
8
+ from torch._utils_internal import upload_graph
9
+ from torch.fx.experimental.optimization import (
10
+ matches_module_pattern,
11
+ replace_node_module,
12
+ )
13
+ from torch.fx.passes.shape_prop import ShapeProp
14
+ from torch.nn import functional as F
15
+ from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
16
+
17
+ from .. import config
18
+
19
+ from ..fx_utils import matches_module_function_pattern
20
+ from ..pattern_matcher import (
21
+ init_once_fakemode,
22
+ PatternMatcherPass,
23
+ stable_topological_sort,
24
+ )
25
+ from ..utils import is_cpu_device, pass_execution_and_save
26
+ from .group_batch_fusion import group_batch_fusion_passes
27
+ from .misc_patterns import numpy_compat_normalization
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
# Pattern-matcher pass instances for the pre-grad (Torch IR) pipeline.
# prevent_match_across_mutations=True keeps matches from spanning graph
# mutations, which is required on this non-functionalized IR.
normalization_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="normalization_pass"
)
merge_splits_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="merge_splits_pass"
)
split_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="split_cat_pass"
)
unbind_stack_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="unbind_stack_pass"
)
efficient_conv_bn_eval_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
)
merge_getitem_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
split_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
unbind_stack_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
merge_getitem_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
69
+
70
+
71
def fuse_parallel_linear_pass(graph):
    # No-op stub in this build; presumably replaced by an internal (fbcode)
    # implementation — invoked from the predispatch path in pre_grad_passes.
    return None
73
+
74
+
75
def remove_split_ops(graph, shape_prop):
    # No-op stub in this build; presumably replaced by an internal (fbcode)
    # implementation — invoked from the predispatch path in pre_grad_passes.
    return None
77
+
78
+
79
# Ordered pipelines applied by pre_grad_passes: Torch-IR passes first list,
# predispatch-aten passes second list. Order matters (e.g. normalization
# before the split/cat rewrites).
pattern_matcher_passes: List[PatternMatcherPass] = [
    normalization_pass,
    merge_getitem_cat_pass,
    merge_splits_pass,
    split_cat_pass,
    unbind_stack_pass,
    efficient_conv_bn_eval_pass,
]
pattern_matcher_passes_aten: List[PatternMatcherPass] = [
    merge_getitem_cat_pass_aten,
    merge_splits_pass_aten,
    split_cat_pass_aten,
    unbind_stack_pass_aten,
]
93
+
94
+
95
@init_once_fakemode
def lazy_init():
    # Importing these modules registers their patterns into the pass
    # instances declared above; run once, under a fake-tensor mode.
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        # Internal-only pattern registrations.
        from . import fb  # type: ignore[attr-defined]  # noqa: F401
101
+
102
+
103
def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is not functional or normalized, so it is harder
    to write passes on this IR.  Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py which
    are after functionalization and normalization.
    """
    if config.pattern_matcher:
        lazy_init()
        # Keep a copy of the pre-pass module for the optional numeric check
        # performed at the end of this function.
        if hasattr(
            config, "fx_passes_numeric_check"
        ) and config.fx_passes_numeric_check.get("pre_grad", False):
            gm_before_fx_passes = gm.__copy__()
        # explicitly run with predispatch atenIR based passes
        if config.is_predispatch:

            def shape_prop(mod) -> None:
                # Re-propagate tensor metadata; several passes below consume
                # node.meta populated by ShapeProp.
                ShapeProp(
                    gm=mod,
                    fake_mode=detect_fake_mode(example_inputs),
                ).propagate(*example_inputs)

            # normalization pass
            pass_execution_and_save(
                normalization_pass_aten.apply,
                gm,
                "[Pre grad(predispatch IR)]Apply normalization pass",
            )
            pass_execution_and_save(
                group_batch_fusion_passes,
                gm,
                "[Pre grad(predispatch IR)] Apply group_batch_fusion",
            )
            pass_execution_and_save(
                fuse_chunk_squeeze_cat_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
            )
            pass_execution_and_save(
                fuse_split_linear_add_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
            )

            log.debug(
                "[Pre grad(predispatch IR)]Before split cat in pre grad pass. graph: %s",
                gm.graph,
            )
            for ind, pattern_matcher_pass_aten in enumerate(
                pattern_matcher_passes_aten
            ):
                pass_execution_and_save(
                    pattern_matcher_pass_aten.apply,
                    gm,
                    f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
                )
            pass_execution_and_save(
                remove_reshape_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
            )
            pass_execution_and_save(
                fuse_parallel_linear_pass,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
            )
            pass_execution_and_save(
                lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                gm,
                "[Pre grad(predispatch IR)] Apply remove_split_ops",
            )
            # Final metadata refresh after all rewrites.
            shape_prop(gm)

        else:
            # We only log the graph with changes to avoid the excessive compilation time
            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
            if example_inputs is not None:
                gm = fuse_fx(gm, example_inputs)
            numpy_compat_normalization(gm.graph)
            # Snapshot counters so we can tell whether each pass changed the
            # graph, and only then upload it for logging.
            inductor_before_change = copy.deepcopy(counters["inductor"])
            group_batch_fusion_passes(gm.graph, pre_grad=True)
            if counters["inductor"] != inductor_before_change:
                optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph(
                    gm.graph
                )
            for pattern_matcher_pass in pattern_matcher_passes:
                inductor_before_change = copy.deepcopy(counters["inductor"])
                pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
                if counters["inductor"] != inductor_before_change:
                    optimus_scuba_log[
                        f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad"
                    ] = upload_graph(gm.graph)

    if config.pre_grad_custom_pass is not None:
        config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)
    gm.graph.lint()
    gm.recompile()

    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        # Compare outputs of the module before/after the passes to catch
        # numeric regressions introduced by any rewrite above.
        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm
225
+
226
+
227
def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """
    Torch-IR fusions: sink cat past pointwise ops, fuse permutes with
    linear/matmul (non-CPU), and fold Identity / conv-bn (CPU inference
    with freezing).
    """
    is_cpu = is_cpu_device(example_inputs)

    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        gm = linear_permute_fusion(gm)
        gm = permute_linear_fusion(gm)
        gm = permute_matmul_fusion(gm)

    # make sure the autograd is disabled.
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        gm = remove_identity(gm)
        gm = fuse_conv_bn(gm)
    return gm
248
+
249
+
250
def fetch_attr(target: str, mod):
    """
    Resolve a dotted attribute path (e.g. ``"sub.weight"``) against ``mod``.

    Args:
        target: Dot-separated attribute path.
        mod: Object (typically a module) to start the lookup from.

    Returns:
        The attribute found at the end of the path.

    Raises:
        RuntimeError: if any atom along the path does not exist; the message
            reports the prefix of the path that was successfully resolved.
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Fixed typo in the original message ("nonexistant").
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
260
+
261
+
262
def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Removes all identity layers from the module.
    """

    class IdentityRemover(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            # nn.Identity forwards its single input unchanged, so the call
            # can be replaced by the input itself.
            if isinstance(self.submodules[target], nn.Identity):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)

    return IdentityRemover(gm).transform()
276
+
277
+
278
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
    """
    Fuses Convolution/BN layers for inference purposes.

    Handles two forms: (nn.ConvNd -> nn.BatchNormNd) module pairs and
    (nn.ConvNd -> F.batch_norm) module/function pairs with constant BN
    arguments. Only eval-mode pairs are fused.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())
    # Phase 1: module-module pairs, fused via fuse_conv_bn_eval.
    for pattern in modules_patterns:
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
    # Phase 2: conv module followed by functional batch_norm with
    # constant (get_attr) stats/affine params.
    for pattern in module_function_patterns:
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue
                # All four BN tensors must be single-use graph constants so
                # they can be folded into the conv weights.
                bn_args_is_constant = all(
                    n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue
                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
    gm.recompile()

    return gm
351
+
352
+
353
class NormalizedLinearNode:
    """Accessor wrapper for an ``F.linear`` call_function node.

    Normalizes access to ``input``/``weight``/``bias`` regardless of whether
    the caller passed them positionally or as keywords.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        positional = self.node.args
        if positional:
            return positional[0]  # type: ignore[return-value]
        return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        positional = self.node.args
        if len(positional) > 1:
            return positional[1]  # type: ignore[return-value]
        return self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        positional = self.node.args
        if len(positional) > 2:
            return positional[2]  # type: ignore[return-value]
        # bias is optional: absent keyword means no bias.
        return self.node.kwargs.get("bias")  # type: ignore[return-value]
376
+
377
+
378
class NormalizedMatmulNode:
    """Accessor wrapper for a ``torch.bmm`` / ``torch.matmul`` call_function
    node; reads ``input``/``other`` uniformly from args or kwargs."""

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        positional = self.node.args
        if positional:
            return positional[0]  # type: ignore[return-value]
        return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        positional = self.node.args
        if len(positional) > 1:
            return positional[1]  # type: ignore[return-value]
        return self.node.kwargs["other"]  # type: ignore[return-value]
395
+
396
+
397
def check_permute(node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is a permute that swaps exactly the last two
    dimensions and leaves all leading dimensions in place.

    The permutation is read either from positional args (``node.args[1:]``)
    or from the ``permutation`` keyword; dims are normalized modulo rank.
    """
    rank = len(node.meta["tensor_meta"].shape)
    kw_perm = node.kwargs.get("permutation")
    if len(node.args) > 3:
        # Positional form: args[0] is the tensor, args[1:] the permutation.
        perm = [node.args[idx] % rank for idx in range(1, rank + 1)]  # type: ignore[operator]
    elif kw_perm is not None and len(kw_perm) > 2:  # type: ignore[arg-type]
        perm = [dim % rank for dim in kw_perm]  # type: ignore[union-attr]
    else:
        return False
    expected = list(range(rank))
    expected[-2], expected[-1] = rank - 1, rank - 2
    return perm == expected
413
+
414
+
415
def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Rewrites cat -> [view...] -> pointwise(unary) into
    pointwise-per-input -> cat, so the pointwise op runs on the smaller
    pre-cat tensors and the cat can fuse with later consumers.
    """
    def one_user(node):
        # The sole user of `node`, or None if it has zero or many users.
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        # Walk through a single-use chain of views after the cat; the
        # pointwise op must be the unique user at the end of the chain.
        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    # Normalizes cat's (args, kwargs) into (tensors, dim).
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                # Apply the pointwise op to each cat input, then re-cat.
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=user.kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                # Users of the pointwise op now read the (possibly viewed)
                # cat chain; users of the old cat read the new cat.
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module
461
+
462
+
463
def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Rewrites linear(...).permute(..., -1, -2) into a single
    linear_transpose(...) call (see linear_transpose below).
    Requires node.meta populated by ShapeProp (via check_permute).
    """
    for node in module.graph.nodes:
        if (
            node.op == "call_method"
            and node.target == "permute"
            and check_permute(node)
        ):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # Drop the original linear if the permute was its only user.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
494
+
495
+
496
# Y1 = X * W^T + bias
# Y2 = Y1.permute(0, 2, 1)
# ---->
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Equivalent of ``F.linear(input, weight, bias)`` followed by a swap of
    the last two dimensions, computed without the explicit permute."""
    product = torch.matmul(weight, input.transpose(-1, -2))
    if bias is None:
        return product
    # bias broadcasts along the (now last) row dimension.
    return product + bias.unsqueeze(-1)
506
+
507
+
508
def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Rewrites linear(x.permute(..., -1, -2), w, b) into a single
    transpose_linear(x, w, b) call (see transpose_linear below).
    Requires node.meta populated by ShapeProp (via check_permute).
    """
    for node in module.graph.nodes:
        if node.op == "call_function" and node.target == torch.nn.functional.linear:
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_method"
                and input_node.target == "permute"
                and check_permute(input_node)
            ):
                normalized = NormalizedLinearNode(node)
                # The fused op receives the permute's own input.
                if len(input_node.args) > 0:
                    input = input_node.args[0]
                else:
                    input = input_node.kwargs["input"]
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_linear, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # Drop the permute if the linear was its only user.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
539
+
540
+
541
def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Rewrites bmm/matmul whose operand(s) come from a last-two-dims permute
    into a single transpose_matmul(A, B, Atrans, Btrans) call, folding the
    permutes into the matmul. Requires node.meta populated by ShapeProp.
    """
    for node in module.graph.nodes:
        if node.op == "call_function" and (
            node.target == torch.bmm or node.target == torch.matmul
        ):
            normalized = NormalizedMatmulNode(node)
            input_A_node = normalized.get_input()
            input_B_node = normalized.get_other()
            input_A = input_A_node
            input_B = input_B_node
            Atrans = Btrans = False
            # If operand A is a qualifying permute, use its input instead.
            if (
                input_A_node.op == "call_method"
                and input_A_node.target == "permute"
                and check_permute(input_A_node)
            ):
                Atrans = True
                if len(input_A_node.args) > 0:
                    input_A = input_A_node.args[0]  # type: ignore[assignment]
                else:
                    input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

            # Same for operand B.
            if (
                input_B_node.op == "call_method"
                and input_B_node.target == "permute"
                and check_permute(input_B_node)
            ):
                Btrans = True
                if len(input_B_node.args) > 0:
                    input_B = input_B_node.args[0]  # type: ignore[assignment]
                else:
                    input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

            if Atrans or Btrans:
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_matmul,
                        args=(input_A, input_B, Atrans, Btrans),
                    )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                # Drop permutes that became dead after the rewrite.
                if Atrans and len(input_A_node.users) == 0:
                    module.graph.erase_node(input_A_node)
                if Btrans and len(input_B_node.users) == 0:
                    module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module
590
+
591
+
592
# X1 = X.permute(0, 2, 1)
# Y1 = X1 * W1^T + bias1
# ---->
# Y2 = X1.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Equivalent of ``F.linear(input.transpose(-1, -2), weight, bias)``,
    with the permute expressed as a transpose folded into the matmul."""
    out = torch.matmul(input.transpose(-1, -2), weight.t())
    if bias is None:
        return out
    return out + bias
602
+
603
+
604
def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    """Matmul with an optional last-two-dims transpose folded into either
    operand; replaces permute -> bmm/matmul chains."""
    lhs = A.transpose(-1, -2) if Atrans else A
    rhs = B.transpose(-1, -2) if Btrans else B
    return torch.matmul(lhs, rhs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc ADDED
Binary file (1.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc ADDED
Binary file (7.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc ADDED
Binary file (54.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc ADDED
Binary file (430 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .core import unify, reify # type: ignore[attr-defined]
2
+ from .variable import isvar
3
+ from .utils import _toposort, freeze
4
+ from .unification_tools import groupby, first # type: ignore[import]
5
+
6
+
7
+ class Dispatcher:
8
+ def __init__(self, name):
9
+ self.name = name
10
+ self.funcs = {}
11
+ self.ordering = []
12
+
13
+ def add(self, signature, func):
14
+ self.funcs[freeze(signature)] = func
15
+ self.ordering = ordering(self.funcs)
16
+
17
+ def __call__(self, *args, **kwargs):
18
+ func, s = self.resolve(args)
19
+ return func(*args, **kwargs)
20
+
21
+ def resolve(self, args):
22
+ n = len(args)
23
+ for signature in self.ordering:
24
+ if len(signature) != n:
25
+ continue
26
+ s = unify(freeze(args), signature)
27
+ if s is not False:
28
+ result = self.funcs[signature]
29
+ return result, s
30
+ raise NotImplementedError("No match found. \nKnown matches: "
31
+ + str(self.ordering) + "\nInput: " + str(args))
32
+
33
+ def register(self, *signature):
34
+ def _(func):
35
+ self.add(signature, func)
36
+ return self
37
+ return _
38
+
39
+
40
+ class VarDispatcher(Dispatcher):
41
+ """ A dispatcher that calls functions with variable names
42
+ >>> # xdoctest: +SKIP
43
+ >>> d = VarDispatcher('d')
44
+ >>> x = var('x')
45
+ >>> @d.register('inc', x)
46
+ ... def f(x):
47
+ ... return x + 1
48
+ >>> @d.register('double', x)
49
+ ... def f(x):
50
+ ... return x * 2
51
+ >>> d('inc', 10)
52
+ 11
53
+ >>> d('double', 10)
54
+ 20
55
+ """
56
+ def __call__(self, *args, **kwargs):
57
+ func, s = self.resolve(args)
58
+ d = {k.token: v for k, v in s.items()}
59
+ return func(**d)
60
+
61
+
62
+ global_namespace = {} # type: ignore[var-annotated]
63
+
64
+
65
+ def match(*signature, **kwargs):
66
+ namespace = kwargs.get('namespace', global_namespace)
67
+ dispatcher = kwargs.get('Dispatcher', Dispatcher)
68
+
69
+ def _(func):
70
+ name = func.__name__
71
+
72
+ if name not in namespace:
73
+ namespace[name] = dispatcher(name)
74
+ d = namespace[name]
75
+
76
+ d.add(signature, func)
77
+
78
+ return d
79
+ return _
80
+
81
+
82
+ def supercedes(a, b):
83
+ """ ``a`` is a more specific match than ``b`` """
84
+ if isvar(b) and not isvar(a):
85
+ return True
86
+ s = unify(a, b)
87
+ if s is False:
88
+ return False
89
+ s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
90
+ if reify(a, s) == a:
91
+ return True
92
+ if reify(b, s) == b:
93
+ return False
94
+
95
+
96
+ # Taken from multipledispatch
97
+ def edge(a, b, tie_breaker=hash):
98
+ """ A should be checked before B
99
+ Tie broken by tie_breaker, defaults to ``hash``
100
+ """
101
+ if supercedes(a, b):
102
+ if supercedes(b, a):
103
+ return tie_breaker(a) > tie_breaker(b)
104
+ else:
105
+ return True
106
+ return False
107
+
108
+
109
+ # Taken from multipledispatch
110
+ def ordering(signatures):
111
+ """ A sane ordering of signatures to check, first to last
112
+ Topological sort of edges as given by ``edge`` and ``supercedes``
113
+ """
114
+ signatures = list(map(tuple, signatures))
115
+ edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
116
+ edges = groupby(first, edges)
117
+ for s in signatures:
118
+ if s not in edges:
119
+ edges[s] = []
120
+ edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
121
+ return _toposort(edges)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc ADDED
Binary file (4.77 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import typename
2
+
3
+ __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
4
+
5
+ class VariadicSignatureType(type):
6
+ # checking if subclass is a subclass of self
7
+ def __subclasscheck__(cls, subclass):
8
+ other_type = (subclass.variadic_type if isvariadic(subclass)
9
+ else (subclass,))
10
+ return subclass is cls or all(
11
+ issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
12
+ )
13
+
14
+ def __eq__(cls, other):
15
+ """
16
+ Return True if other has the same variadic type
17
+ Parameters
18
+ ----------
19
+ other : object (type)
20
+ The object (type) to check
21
+ Returns
22
+ -------
23
+ bool
24
+ Whether or not `other` is equal to `self`
25
+ """
26
+ return (isvariadic(other) and
27
+ set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
28
+
29
+ def __hash__(cls):
30
+ return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
31
+
32
+
33
+ def isvariadic(obj):
34
+ """Check whether the type `obj` is variadic.
35
+ Parameters
36
+ ----------
37
+ obj : type
38
+ The type to check
39
+ Returns
40
+ -------
41
+ bool
42
+ Whether or not `obj` is variadic
43
+ Examples
44
+ --------
45
+ >>> # xdoctest: +SKIP
46
+ >>> isvariadic(int)
47
+ False
48
+ >>> isvariadic(Variadic[int])
49
+ True
50
+ """
51
+ return isinstance(obj, VariadicSignatureType)
52
+
53
+
54
+ class VariadicSignatureMeta(type):
55
+ """A metaclass that overrides ``__getitem__`` on the class. This is used to
56
+ generate a new type for Variadic signatures. See the Variadic class for
57
+ examples of how this behaves.
58
+ """
59
+ def __getitem__(cls, variadic_type):
60
+ if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
61
+ raise ValueError("Variadic types must be type or tuple of types"
62
+ " (Variadic[int] or Variadic[(int, float)]")
63
+
64
+ if not isinstance(variadic_type, tuple):
65
+ variadic_type = variadic_type,
66
+ return VariadicSignatureType(
67
+ f'Variadic[{typename(variadic_type)}]',
68
+ (),
69
+ dict(variadic_type=variadic_type, __slots__=())
70
+ )
71
+
72
+
73
+ class Variadic(metaclass=VariadicSignatureMeta):
74
+ """A class whose getitem method can be used to generate a new type
75
+ representing a specific variadic signature.
76
+ Examples
77
+ --------
78
+ >>> # xdoctest: +SKIP
79
+ >>> Variadic[int] # any number of int arguments
80
+ <class 'multipledispatch.variadic.Variadic[int]'>
81
+ >>> Variadic[(int, str)] # any number of one of int or str arguments
82
+ <class 'multipledispatch.variadic.Variadic[(int, str)]'>
83
+ >>> issubclass(int, Variadic[int])
84
+ True
85
+ >>> issubclass(int, Variadic[(int, str)])
86
+ True
87
+ >>> issubclass(str, Variadic[(int, str)])
88
+ True
89
+ >>> issubclass(float, Variadic[(int, str)])
90
+ False
91
+ """
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import operator
3
+ from functools import reduce
4
+ from collections.abc import Mapping
5
+
6
+ __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
7
+ 'valfilter', 'keyfilter', 'itemfilter',
8
+ 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
9
+
10
+
11
+ def _get_factory(f, kwargs):
12
+ factory = kwargs.pop('factory', dict)
13
+ if kwargs:
14
+ raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
15
+ return factory
16
+
17
+
18
+ def merge(*dicts, **kwargs):
19
+ """ Merge a collection of dictionaries
20
+
21
+ >>> merge({1: 'one'}, {2: 'two'})
22
+ {1: 'one', 2: 'two'}
23
+
24
+ Later dictionaries have precedence
25
+
26
+ >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
27
+ {1: 2, 3: 3, 4: 4}
28
+
29
+ See Also:
30
+ merge_with
31
+ """
32
+ if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
33
+ dicts = dicts[0]
34
+ factory = _get_factory(merge, kwargs)
35
+
36
+ rv = factory()
37
+ for d in dicts:
38
+ rv.update(d)
39
+ return rv
40
+
41
+
42
+ def merge_with(func, *dicts, **kwargs):
43
+ """ Merge dictionaries and apply function to combined values
44
+
45
+ A key may occur in more than one dict, and all values mapped from the key
46
+ will be passed to the function as a list, such as func([val1, val2, ...]).
47
+
48
+ >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
49
+ {1: 11, 2: 22}
50
+
51
+ >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
52
+ {1: 1, 2: 2, 3: 30}
53
+
54
+ See Also:
55
+ merge
56
+ """
57
+ if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
58
+ dicts = dicts[0]
59
+ factory = _get_factory(merge_with, kwargs)
60
+
61
+ result = factory()
62
+ for d in dicts:
63
+ for k, v in d.items():
64
+ if k not in result:
65
+ result[k] = [v]
66
+ else:
67
+ result[k].append(v)
68
+ return valmap(func, result, factory)
69
+
70
+
71
+ def valmap(func, d, factory=dict):
72
+ """ Apply function to values of dictionary
73
+
74
+ >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
75
+ >>> valmap(sum, bills) # doctest: +SKIP
76
+ {'Alice': 65, 'Bob': 45}
77
+
78
+ See Also:
79
+ keymap
80
+ itemmap
81
+ """
82
+ rv = factory()
83
+ rv.update(zip(d.keys(), map(func, d.values())))
84
+ return rv
85
+
86
+
87
+ def keymap(func, d, factory=dict):
88
+ """ Apply function to keys of dictionary
89
+
90
+ >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
91
+ >>> keymap(str.lower, bills) # doctest: +SKIP
92
+ {'alice': [20, 15, 30], 'bob': [10, 35]}
93
+
94
+ See Also:
95
+ valmap
96
+ itemmap
97
+ """
98
+ rv = factory()
99
+ rv.update(zip(map(func, d.keys()), d.values()))
100
+ return rv
101
+
102
+
103
+ def itemmap(func, d, factory=dict):
104
+ """ Apply function to items of dictionary
105
+
106
+ >>> accountids = {"Alice": 10, "Bob": 20}
107
+ >>> itemmap(reversed, accountids) # doctest: +SKIP
108
+ {10: "Alice", 20: "Bob"}
109
+
110
+ See Also:
111
+ keymap
112
+ valmap
113
+ """
114
+ rv = factory()
115
+ rv.update(map(func, d.items()))
116
+ return rv
117
+
118
+
119
+ def valfilter(predicate, d, factory=dict):
120
+ """ Filter items in dictionary by value
121
+
122
+ >>> iseven = lambda x: x % 2 == 0
123
+ >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
124
+ >>> valfilter(iseven, d)
125
+ {1: 2, 3: 4}
126
+
127
+ See Also:
128
+ keyfilter
129
+ itemfilter
130
+ valmap
131
+ """
132
+ rv = factory()
133
+ for k, v in d.items():
134
+ if predicate(v):
135
+ rv[k] = v
136
+ return rv
137
+
138
+
139
+ def keyfilter(predicate, d, factory=dict):
140
+ """ Filter items in dictionary by key
141
+
142
+ >>> iseven = lambda x: x % 2 == 0
143
+ >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
144
+ >>> keyfilter(iseven, d)
145
+ {2: 3, 4: 5}
146
+
147
+ See Also:
148
+ valfilter
149
+ itemfilter
150
+ keymap
151
+ """
152
+ rv = factory()
153
+ for k, v in d.items():
154
+ if predicate(k):
155
+ rv[k] = v
156
+ return rv
157
+
158
+
159
+ def itemfilter(predicate, d, factory=dict):
160
+ """ Filter items in dictionary by item
161
+
162
+ >>> def isvalid(item):
163
+ ... k, v = item
164
+ ... return k % 2 == 0 and v < 4
165
+
166
+ >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
167
+ >>> itemfilter(isvalid, d)
168
+ {2: 3}
169
+
170
+ See Also:
171
+ keyfilter
172
+ valfilter
173
+ itemmap
174
+ """
175
+ rv = factory()
176
+ for item in d.items():
177
+ if predicate(item):
178
+ k, v = item
179
+ rv[k] = v
180
+ return rv
181
+
182
+
183
+ def assoc(d, key, value, factory=dict):
184
+ """ Return a new dict with new key value pair
185
+
186
+ New dict has d[key] set to value. Does not modify the initial dictionary.
187
+
188
+ >>> assoc({'x': 1}, 'x', 2)
189
+ {'x': 2}
190
+ >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
191
+ {'x': 1, 'y': 3}
192
+ """
193
+ d2 = factory()
194
+ d2.update(d)
195
+ d2[key] = value
196
+ return d2
197
+
198
+
199
+ def dissoc(d, *keys, **kwargs):
200
+ """ Return a new dict with the given key(s) removed.
201
+
202
+ New dict has d[key] deleted for each supplied key.
203
+ Does not modify the initial dictionary.
204
+
205
+ >>> dissoc({'x': 1, 'y': 2}, 'y')
206
+ {'x': 1}
207
+ >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
208
+ {}
209
+ >>> dissoc({'x': 1}, 'y') # Ignores missing keys
210
+ {'x': 1}
211
+ """
212
+ factory = _get_factory(dissoc, kwargs)
213
+ d2 = factory()
214
+
215
+ if len(keys) < len(d) * .6:
216
+ d2.update(d)
217
+ for key in keys:
218
+ if key in d2:
219
+ del d2[key]
220
+ else:
221
+ remaining = set(d)
222
+ remaining.difference_update(keys)
223
+ for k in remaining:
224
+ d2[k] = d[k]
225
+ return d2
226
+
227
+
228
+ def assoc_in(d, keys, value, factory=dict):
229
+ """ Return a new dict with new, potentially nested, key value pair
230
+
231
+ >>> purchase = {'name': 'Alice',
232
+ ... 'order': {'items': ['Apple', 'Orange'],
233
+ ... 'costs': [0.50, 1.25]},
234
+ ... 'credit card': '5555-1234-1234-1234'}
235
+ >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
236
+ {'credit card': '5555-1234-1234-1234',
237
+ 'name': 'Alice',
238
+ 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
239
+ """
240
+ return update_in(d, keys, lambda x: value, value, factory)
241
+
242
+
243
+ def update_in(d, keys, func, default=None, factory=dict):
244
+ """ Update value in a (potentially) nested dictionary
245
+
246
+ inputs:
247
+ d - dictionary on which to operate
248
+ keys - list or tuple giving the location of the value to be changed in d
249
+ func - function to operate on that value
250
+
251
+ If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
252
+ original dictionary with v replaced by func(v), but does not mutate the
253
+ original dictionary.
254
+
255
+ If k0 is not a key in d, update_in creates nested dictionaries to the depth
256
+ specified by the keys, with the innermost value set to func(default).
257
+
258
+ >>> inc = lambda x: x + 1
259
+ >>> update_in({'a': 0}, ['a'], inc)
260
+ {'a': 1}
261
+
262
+ >>> transaction = {'name': 'Alice',
263
+ ... 'purchase': {'items': ['Apple', 'Orange'],
264
+ ... 'costs': [0.50, 1.25]},
265
+ ... 'credit card': '5555-1234-1234-1234'}
266
+ >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
267
+ {'credit card': '5555-1234-1234-1234',
268
+ 'name': 'Alice',
269
+ 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
270
+
271
+ >>> # updating a value when k0 is not in d
272
+ >>> update_in({}, [1, 2, 3], str, default="bar")
273
+ {1: {2: {3: 'bar'}}}
274
+ >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
275
+ {1: 'foo', 2: {3: {4: 1}}}
276
+ """
277
+ ks = iter(keys)
278
+ k = next(ks)
279
+
280
+ rv = inner = factory()
281
+ rv.update(d)
282
+
283
+ for key in ks:
284
+ if k in d:
285
+ d = d[k]
286
+ dtemp = factory()
287
+ dtemp.update(d)
288
+ else:
289
+ d = dtemp = factory()
290
+
291
+ inner[k] = inner = dtemp
292
+ k = key
293
+
294
+ if k in d:
295
+ inner[k] = func(d[k])
296
+ else:
297
+ inner[k] = func(default)
298
+ return rv
299
+
300
+
301
+ def get_in(keys, coll, default=None, no_default=False):
302
+ """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
303
+
304
+ If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
305
+ ``no_default`` is specified, then it raises KeyError or IndexError.
306
+
307
+ ``get_in`` is a generalization of ``operator.getitem`` for nested data
308
+ structures such as dictionaries and lists.
309
+
310
+ >>> transaction = {'name': 'Alice',
311
+ ... 'purchase': {'items': ['Apple', 'Orange'],
312
+ ... 'costs': [0.50, 1.25]},
313
+ ... 'credit card': '5555-1234-1234-1234'}
314
+ >>> get_in(['purchase', 'items', 0], transaction)
315
+ 'Apple'
316
+ >>> get_in(['name'], transaction)
317
+ 'Alice'
318
+ >>> get_in(['purchase', 'total'], transaction)
319
+ >>> get_in(['purchase', 'items', 'apple'], transaction)
320
+ >>> get_in(['purchase', 'items', 10], transaction)
321
+ >>> get_in(['purchase', 'total'], transaction, 0)
322
+ 0
323
+ >>> get_in(['y'], {}, no_default=True)
324
+ Traceback (most recent call last):
325
+ ...
326
+ KeyError: 'y'
327
+
328
+ See Also:
329
+ itertoolz.get
330
+ operator.getitem
331
+ """
332
+ try:
333
+ return reduce(operator.getitem, keys, coll)
334
+ except (KeyError, IndexError, TypeError):
335
+ if no_default:
336
+ raise
337
+ return default
338
+
339
+
340
+ def getter(index):
341
+ if isinstance(index, list):
342
+ if len(index) == 1:
343
+ index = index[0]
344
+ return lambda x: (x[index],)
345
+ elif index:
346
+ return operator.itemgetter(*index)
347
+ else:
348
+ return lambda x: ()
349
+ else:
350
+ return operator.itemgetter(index)
351
+
352
+
353
+ def groupby(key, seq):
354
+ """ Group a collection by a key function
355
+
356
+ >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
357
+ >>> groupby(len, names) # doctest: +SKIP
358
+ {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
359
+
360
+ >>> iseven = lambda x: x % 2 == 0
361
+ >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
362
+ {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
363
+
364
+ Non-callable keys imply grouping on a member.
365
+
366
+ >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
367
+ ... {'name': 'Bob', 'gender': 'M'},
368
+ ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
369
+ {'F': [{'gender': 'F', 'name': 'Alice'}],
370
+ 'M': [{'gender': 'M', 'name': 'Bob'},
371
+ {'gender': 'M', 'name': 'Charlie'}]}
372
+
373
+ Not to be confused with ``itertools.groupby``
374
+
375
+ See Also:
376
+ countby
377
+ """
378
+ if not callable(key):
379
+ key = getter(key)
380
+ d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
381
+ for item in seq:
382
+ d[key(item)](item)
383
+ rv = {}
384
+ for k, v in d.items():
385
+ rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
386
+ return rv
387
+
388
+
389
+ def first(seq):
390
+ """ The first element in a sequence
391
+
392
+ >>> first('ABC')
393
+ 'A'
394
+ """
395
+ return next(iter(seq))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
2
+ def hashable(x):
3
+ try:
4
+ hash(x)
5
+ return True
6
+ except TypeError:
7
+ return False
8
+
9
+
10
+ def transitive_get(key, d):
11
+ """ Transitive dict.get
12
+ >>> d = {1: 2, 2: 3, 3: 4}
13
+ >>> d.get(1)
14
+ 2
15
+ >>> transitive_get(1, d)
16
+ 4
17
+ """
18
+ while hashable(key) and key in d:
19
+ key = d[key]
20
+ return key
21
+
22
+
23
+ def raises(err, lamda):
24
+ try:
25
+ lamda()
26
+ return False
27
+ except err:
28
+ return True
29
+
30
+
31
+ # Taken from theano/theano/gof/sched.py
32
+ # Avoids licensing issues because this was written by Matthew Rocklin
33
+ def _toposort(edges):
34
+ """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
35
+ inputs:
36
+ edges - a dict of the form {a: {b, c}} where b and c depend on a
37
+ outputs:
38
+ L - an ordered list of nodes that satisfy the dependencies of edges
39
+ >>> # xdoctest: +SKIP
40
+ >>> _toposort({1: (2, 3), 2: (3, )})
41
+ [1, 2, 3]
42
+ Closely follows the wikipedia page [2]
43
+ [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
44
+ Communications of the ACM
45
+ [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
46
+ """
47
+ incoming_edges = reverse_dict(edges)
48
+ incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
49
+ S = ({v for v in edges if v not in incoming_edges})
50
+ L = []
51
+
52
+ while S:
53
+ n = S.pop()
54
+ L.append(n)
55
+ for m in edges.get(n, ()):
56
+ assert n in incoming_edges[m]
57
+ incoming_edges[m].remove(n)
58
+ if not incoming_edges[m]:
59
+ S.add(m)
60
+ if any(incoming_edges.get(v, None) for v in edges):
61
+ raise ValueError("Input has cycles")
62
+ return L
63
+
64
+
65
+ def reverse_dict(d):
66
+ """Reverses direction of dependence dict
67
+ >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
68
+ >>> reverse_dict(d) # doctest: +SKIP
69
+ {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
70
+ :note: dict order are not deterministic. As we iterate on the
71
+ input dict, it make the output of this function depend on the
72
+ dict order. So this function output order should be considered
73
+ as undeterministic.
74
+ """
75
+ result = {} # type: ignore[var-annotated]
76
+ for key in d:
77
+ for val in d[key]:
78
+ result[val] = result.get(val, tuple()) + (key, )
79
+ return result
80
+
81
+
82
+ def xfail(func):
83
+ try:
84
+ func()
85
+ raise Exception("XFailed test passed") # pragma:nocover
86
+ except Exception:
87
+ pass
88
+
89
+
90
+ def freeze(d):
91
+ """ Freeze container to hashable form
92
+ >>> freeze(1)
93
+ 1
94
+ >>> freeze([1, 2])
95
+ (1, 2)
96
+ >>> freeze({1: 2}) # doctest: +SKIP
97
+ frozenset([(1, 2)])
98
+ """
99
+ if isinstance(d, dict):
100
+ return frozenset(map(freeze, d.items()))
101
+ if isinstance(d, set):
102
+ return frozenset(map(freeze, d))
103
+ if isinstance(d, (tuple, list)):
104
+ return tuple(map(freeze, d))
105
+ return d
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from .utils import hashable
3
+ from .dispatch import dispatch
4
+
5
+ _global_logic_variables = set() # type: ignore[var-annotated]
6
+ _glv = _global_logic_variables
7
+
8
+
9
+ class Var:
10
+ """ Logic Variable """
11
+
12
+ _id = 1
13
+
14
+ def __new__(cls, *token):
15
+ if len(token) == 0:
16
+ token = f"_{Var._id}" # type: ignore[assignment]
17
+ Var._id += 1
18
+ elif len(token) == 1:
19
+ token = token[0]
20
+
21
+ obj = object.__new__(cls)
22
+ obj.token = token # type: ignore[attr-defined]
23
+ return obj
24
+
25
+ def __str__(self):
26
+ return "~" + str(self.token) # type: ignore[attr-defined]
27
+ __repr__ = __str__
28
+
29
+ def __eq__(self, other):
30
+ return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined]
31
+
32
+ def __hash__(self):
33
+ return hash((type(self), self.token)) # type: ignore[attr-defined]
34
+
35
+
36
+ def var():
37
+ return lambda *args: Var(*args)
38
+
39
+
40
+ def vars():
41
+ return lambda n: [var() for i in range(n)]
42
+
43
+
44
+ @dispatch(Var)
45
+ def isvar(v):
46
+ return True
47
+
48
+ isvar
49
+
50
+
51
+ @dispatch(object) # type: ignore[no-redef]
52
+ def isvar(o):
53
+ return not not _glv and hashable(o) and o in _glv
54
+
55
+
56
+ @contextmanager
57
+ def variables(*variables):
58
+ """
59
+ Context manager for logic variables
60
+
61
+ Example:
62
+ >>> # xdoctest: +SKIP("undefined vars")
63
+ >>> from __future__ import with_statement
64
+ >>> with variables(1):
65
+ ... print(isvar(1))
66
+ True
67
+ >>> print(isvar(1))
68
+ False
69
+ >>> # Normal approach
70
+ >>> from unification import unify
71
+ >>> x = var('x')
72
+ >>> unify(x, 1)
73
+ {~x: 1}
74
+ >>> # Context Manager approach
75
+ >>> with variables('x'):
76
+ ... print(unify('x', 1))
77
+ {'x': 1}
78
+ """
79
+ old_global_logic_variables = _global_logic_variables.copy()
80
+ _global_logic_variables.update(set(variables))
81
+ try:
82
+ yield
83
+ finally:
84
+ _global_logic_variables.clear()
85
+ _global_logic_variables.update(old_global_logic_variables)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/core/Generator.h>
6
+ #include <ATen/core/PhiloxRNGEngine.h>
7
+ #include <c10/core/GeneratorImpl.h>
8
+ #include <c10/util/Optional.h>
9
+
10
+ namespace at {
11
+ namespace mps::detail {
12
+
13
+ static const uint32_t PHILOX_STATE_N = 7;
14
+ struct rng_data_pod {
15
+ std::array<uint32_t, PHILOX_STATE_N> state{1};
16
+ uint64_t seed = default_rng_seed_val;
17
+ };
18
+
19
+ TORCH_API const Generator& getDefaultMPSGenerator();
20
+ TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
21
+
22
+ } // namespace mps::detail
23
+
24
+ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
25
+ // Constructors
26
+ MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
27
+ ~MPSGeneratorImpl() override = default;
28
+
29
+ // MPSGeneratorImpl methods
30
+ std::shared_ptr<MPSGeneratorImpl> clone() const;
31
+ void set_current_seed(uint64_t seed) override;
32
+ void set_offset(uint64_t offset) override;
33
+ uint64_t get_offset() const override;
34
+ uint64_t current_seed() const override;
35
+ uint64_t seed() override;
36
+ void set_state(const c10::TensorImpl& new_state) override;
37
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
38
+ void update_philox_counters();
39
+
40
+ void set_engine(at::Philox4_32 engine) { engine_ = engine; };
41
+ at::Philox4_32 engine() { return engine_; };
42
+ uint32_t* state_data() { return data_.state.data(); }
43
+ static DeviceType device_type() { return DeviceType::MPS; };
44
+
45
+ private:
46
+ mps::detail::rng_data_pod data_;
47
+ at::Philox4_32 engine_;
48
+
49
+ MPSGeneratorImpl* clone_impl() const override;
50
+ };
51
+
52
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/detail/MPSHooksInterface.h>
6
+ #include <ATen/Generator.h>
7
+ #include <ATen/mps/MPSEvent.h>
8
+ #include <c10/util/Optional.h>
9
+
10
+ namespace at::mps {
11
+
12
+ // The real implementation of MPSHooksInterface
13
+ struct MPSHooks : public at::MPSHooksInterface {
14
+ MPSHooks(at::MPSHooksArgs) {}
15
+ void initMPS() const override;
16
+
17
+ // MPSDevice interface
18
+ bool hasMPS() const override;
19
+ bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
20
+
21
+ // MPSGeneratorImpl interface
22
+ const Generator& getDefaultMPSGenerator() const override;
23
+
24
+ // MPSStream interface
25
+ void deviceSynchronize() const override;
26
+ void commitStream() const override;
27
+ void* getCommandBuffer() const override;
28
+ void* getDispatchQueue() const override;
29
+
30
+ // MPSAllocator interface
31
+ Allocator* getMPSDeviceAllocator() const override;
32
+ void emptyCache() const override;
33
+ size_t getCurrentAllocatedMemory() const override;
34
+ size_t getDriverAllocatedMemory() const override;
35
+ void setMemoryFraction(double ratio) const override;
36
+
37
+ // MPSProfiler interface
38
+ void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
39
+ void profilerStopTrace() const override;
40
+
41
+ // MPSEvent interface
42
+ uint32_t acquireEvent(bool enable_timing) const override;
43
+ void releaseEvent(uint32_t event_id) const override;
44
+ void recordEvent(uint32_t event_id) const override;
45
+ void waitForEvent(uint32_t event_id) const override;
46
+ void synchronizeEvent(uint32_t event_id) const override;
47
+ bool queryEvent(uint32_t event_id) const override;
48
+ double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
49
+
50
+ // Compatibility with Accelerator API
51
+ bool hasPrimaryContext(DeviceIndex device_index) const override {
52
+ // When MPS is available, it is always in use for the one device.
53
+ return true;
54
+ }
55
+ };
56
+
57
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <cstdint>
6
+ #include <utility>
7
+
8
+ #include <c10/core/DeviceGuard.h>
9
+ #include <c10/util/Exception.h>
10
+ #include <c10/core/Stream.h>
11
+ #include <ATen/mps/MPSDevice.h>
12
+
13
+ #ifdef __OBJC__
14
+ #include <Foundation/Foundation.h>
15
+ #include <Metal/Metal.h>
16
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
17
+ #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
18
+ typedef id<MTLCommandQueue> MTLCommandQueue_t;
19
+ typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
20
+ typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
21
+ typedef id<MTLSharedEvent> MTLSharedEvent_t;
22
+ typedef id<MTLDevice> MTLDevice_t;
23
+ #else
24
+ typedef void* MTLCommandQueue_t;
25
+ typedef void* MTLCommandQueue;
26
+ typedef void* MTLCommandBuffer_t;
27
+ typedef void* MTLCommandBuffer;
28
+ typedef void* MTLComputeCommandEncoder_t;
29
+ typedef void* MTLSharedEvent_t;
30
+ typedef void* dispatch_queue_t;
31
+ typedef void* MTLDevice_t;
32
+ #define nil NULL;
33
+ #endif
34
+
35
+
36
+ namespace at::mps {
37
+
38
+ //-----------------------------------------------------------------
39
+ // MPSStream
40
+ //-----------------------------------------------------------------
41
+
42
+ enum class SyncType {
43
+ NONE, // no commit to command buffer
44
+ COMMIT, // commit and flush the command buffer
45
+ COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
46
+ COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
47
+ COMMIT_ADAPTIVE, // commit adaptively based on available memory
48
+ };
49
+
50
+ class TORCH_API MPSStream
51
+ {
52
+ public:
53
+ enum Unchecked { UNCHECKED };
54
+
55
+ /// Construct a MPSStream from a Stream. This construction is checked,
56
+ /// and will raise an error if the Stream is not, in fact, a MPS stream.
57
+ explicit MPSStream(Stream stream);
58
+
59
+ ~MPSStream();
60
+ MTLCommandQueue_t commandQueue() const { return _commandQueue; };
61
+ dispatch_queue_t queue() const { return _serialQueue; }
62
+
63
+ MPSCommandBuffer* commandBuffer();
64
+ MTLComputeCommandEncoder_t commandEncoder();
65
+ void endKernelCoalescing();
66
+ void synchronize(SyncType syncType);
67
+ void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
68
+ void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
69
+ size_t length, size_t srcOffset, size_t dstOffset,
70
+ uint64_t profileId, SyncType syncType = SyncType::NONE);
71
+ void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
72
+ size_t length, size_t srcOffset, size_t dstOffset,
73
+ bool non_blocking, uint64_t profileId);
74
+ void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
75
+ void addCompletedHandler(MTLCommandBufferHandler block);
76
+
77
+ /// Get the MPS device index that this stream is associated with.
78
+ c10::DeviceIndex device_index() const { return _stream.device_index(); }
79
+
80
+ MTLCommandQueue_t stream() const { return _commandQueue; };
81
+
82
+ MTLDevice_t device() const { return [_commandQueue device];}
83
+
84
+ /// Explicit conversion to Stream.
85
+ Stream unwrap() const { return _stream; }
86
+
87
+ private:
88
+ Stream _stream;
89
+ MTLCommandQueue_t _commandQueue = nil;
90
+ MPSCommandBuffer* _commandBuffer = nil;
91
+ MPSCommandBuffer* _prevCommandBuffer = nil;
92
+ MTLComputeCommandEncoder_t _commandEncoder = nil;
93
+ MPSGraphExecutionDescriptor *_executionDescriptor = nil;
94
+ MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
95
+ dispatch_queue_t _serialQueue = nullptr;
96
+ // CommitAndContinue is enabled by default
97
+ bool _enableCommitAndContinue = true;
98
+
99
+ // use synchronize() to access any of these commit functions outside MPSStream
100
+ void commit();
101
+ void commitAndWait();
102
+ void commitAndContinue();
103
+ void flush();
104
+ };
105
+
106
+ /**
107
+ * Get the current MPS stream
108
+ */
109
+ TORCH_API MPSStream* getCurrentMPSStream();
110
+
111
+ /**
112
+ * Get the default MPS stream
113
+ */
114
+ TORCH_API MPSStream* getDefaultMPSStream();
115
+
116
+ //-----------------------------------------------------------------
117
+ // MPSStreamImpl
118
+ //-----------------------------------------------------------------
119
+
120
+ class TORCH_API MPSStreamImpl
121
+ {
122
+ public:
123
+ /**
124
+ * Gets single instance of the MPSStream.
125
+ */
126
+ static MPSStream* getInstance();
127
+
128
+ private:
129
+ static MPSStream* _stream;
130
+ MPSStreamImpl();
131
+ };
132
+
133
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _amp_foreach_non_finite_check_and_unscale_ {
18
+ using schema = void (at::TensorList, at::Tensor &, const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale_")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()")
24
+ static void call(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
25
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
26
+ };
27
+
28
+ struct TORCH_API _amp_foreach_non_finite_check_and_unscale_out {
29
+ using schema = void (at::TensorList, at::Tensor &, const at::Tensor &, at::TensorList);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()")
35
+ static void call(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out);
36
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out);
37
+ };
38
+
39
+ struct TORCH_API _amp_foreach_non_finite_check_and_unscale {
40
+ using schema = ::std::tuple<::std::vector<at::Tensor>,at::Tensor> (at::TensorList, const at::Tensor &, const at::Tensor &);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)")
46
+ static ::std::tuple<::std::vector<at::Tensor>,at::Tensor> call(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale);
47
+ static ::std::tuple<::std::vector<at::Tensor>,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale);
48
+ };
49
+
50
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_copy_from_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor & _copy_from_out(const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out);
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
26
+ inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
27
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
28
+ }
29
+ namespace symint {
30
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
31
+ at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
32
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
33
+ }
34
+ }
35
+
36
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
37
+ inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
38
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
39
+ }
40
+ namespace symint {
41
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
42
+ at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
43
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
44
+ }
45
+ }
46
+
47
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
48
+ inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
49
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
50
+ }
51
+ namespace symint {
52
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
53
+ at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
54
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
55
+ }
56
+ }
57
+
58
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
59
+ inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
60
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
61
+ }
62
+ namespace symint {
63
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
64
+ at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
65
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
66
+ }
67
+ }
68
+
69
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
70
+ inline at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
71
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
72
+ }
73
+ namespace symint {
74
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
75
+ at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
76
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
77
+ }
78
+ }
79
+
80
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
81
+ inline at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
82
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
83
+ }
84
+ namespace symint {
85
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
86
+ at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
87
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
88
+ }
89
+ }
90
+
91
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
92
+ inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
93
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
94
+ }
95
+ namespace symint {
96
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
97
+ at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
98
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
99
+ }
100
+ }
101
+
102
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
103
+ inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
104
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
105
+ }
106
+ namespace symint {
107
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
108
+ at::Tensor & _empty_per_channel_affine_quantized_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
109
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
110
+ }
111
+ }
112
+
113
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2r_cuda_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor _fft_c2r(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size);
21
+ TORCH_API at::Tensor _fft_c2r_symint(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size);
22
+ TORCH_API at::Tensor & _fft_c2r_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size);
23
+ TORCH_API at::Tensor & _fft_c2r_outf(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size, at::Tensor & out);
24
+ TORCH_API at::Tensor & _fft_c2r_symint_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size);
25
+ TORCH_API at::Tensor & _fft_c2r_symint_outf(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out);
26
+
27
+ } // namespace cuda
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_flash_attention_forward.h ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_flash_attention_forward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
26
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
27
+ return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
28
+ }
29
+ namespace symint {
30
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
31
+ ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
32
+ return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
33
+ }
34
+ }
35
+
36
+ // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
37
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward_symint(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
38
+ return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
39
+ }
40
+ namespace symint {
41
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
42
+ ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
43
+ return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
44
+ }
45
+ }
46
+
47
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API void _foreach_addcmul_Scalar_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out);
20
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
21
+ TORCH_API void foreach_tensor_addcmul_scalar_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
22
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
23
+ TORCH_API void foreach_tensor_addcmul_scalar_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
24
+ TORCH_API void _foreach_addcmul_ScalarList_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out);
25
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
26
+ TORCH_API void foreach_tensor_addcmul_scalarlist_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
27
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
28
+ TORCH_API void foreach_tensor_addcmul_scalarlist_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
29
+ TORCH_API void _foreach_addcmul_Tensor_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out);
30
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
31
+ TORCH_API void foreach_tensor_addcmul_tensor_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
32
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
33
+ TORCH_API void foreach_tensor_addcmul_tensor_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
34
+ } // namespace native
35
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_erfc_cuda_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::vector<at::Tensor> _foreach_erfc(at::TensorList self);
21
+ TORCH_API void _foreach_erfc_(at::TensorList self);
22
+
23
+ } // namespace cuda
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round.h ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_foreach_round_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_foreach_round(Tensor[] self) -> Tensor[]
26
+ inline ::std::vector<at::Tensor> _foreach_round(at::TensorList self) {
27
+ return at::_ops::_foreach_round::call(self);
28
+ }
29
+
30
+ // aten::_foreach_round_(Tensor(a!)[] self) -> ()
31
+ inline void _foreach_round_(at::TensorList self) {
32
+ return at::_ops::_foreach_round_::call(self);
33
+ }
34
+
35
+ // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
36
+ inline void _foreach_round_out(at::TensorList out, at::TensorList self) {
37
+ return at::_ops::_foreach_round_out::call(self, out);
38
+ }
39
+ // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
40
+ inline void _foreach_round_outf(at::TensorList self, at::TensorList out) {
41
+ return at::_ops::_foreach_round_out::call(self, out);
42
+ }
43
+
44
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_sin_native.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API void _foreach_sin_out(at::TensorList self, at::TensorList out);
20
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_sin_slow(at::TensorList self);
21
+ TORCH_API void foreach_tensor_sin_slow_(at::TensorList self);
22
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_sin_cuda(at::TensorList self);
23
+ TORCH_API void foreach_tensor_sin_cuda_(at::TensorList self);
24
+ } // namespace native
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_tanh_cpu_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API ::std::vector<at::Tensor> _foreach_tanh(at::TensorList self);
21
+ TORCH_API void _foreach_tanh_(at::TensorList self);
22
+
23
+ } // namespace cpu
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_functional_assert_scalar_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor _functional_assert_scalar(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token);
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _linalg_svd {
18
+ using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, bool, bool, c10::optional<c10::string_view>);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_linalg_svd")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)")
24
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver);
25
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver);
26
+ };
27
+
28
+ struct TORCH_API _linalg_svd_U {
29
+ using schema = ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> (const at::Tensor &, bool, bool, c10::optional<c10::string_view>, at::Tensor &, at::Tensor &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_linalg_svd")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "U")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)")
35
+ static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> call(const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh);
36
+ static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_nested_from_padded_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
26
+ inline at::Tensor _nested_from_padded(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
27
+ return at::_ops::_nested_from_padded::call(padded, cpu_nested_shape_example, fuse_transform_0213);
28
+ }
29
+
30
+ // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & _nested_from_padded_out(at::Tensor & out, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
32
+ return at::_ops::_nested_from_padded_out::call(padded, cpu_nested_shape_example, fuse_transform_0213, out);
33
+ }
34
+ // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & _nested_from_padded_outf(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out) {
36
+ return at::_ops::_nested_from_padded_out::call(padded, cpu_nested_shape_example, fuse_transform_0213, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor _nnpack_spatial_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1);
21
+ TORCH_API at::Tensor _nnpack_spatial_convolution_symint(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1));
22
+ TORCH_API at::Tensor & _nnpack_spatial_convolution_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1);
23
+ TORCH_API at::Tensor & _nnpack_spatial_convolution_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out);
24
+ TORCH_API at::Tensor & _nnpack_spatial_convolution_symint_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1));
25
+ TORCH_API at::Tensor & _nnpack_spatial_convolution_symint_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
26
+
27
+ } // namespace compositeexplicitautograd
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_prelu_kernel_backward_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> _prelu_kernel_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_transform_bias_rescale_qkv_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
26
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _transform_bias_rescale_qkv(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
27
+ return at::_ops::_transform_bias_rescale_qkv::call(qkv, qkv_bias, num_heads);
28
+ }
29
+
30
+ // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
31
+ inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
32
+ return at::_ops::_transform_bias_rescale_qkv_out::call(qkv, qkv_bias, num_heads, out0, out1, out2);
33
+ }
34
+ // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
35
+ inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_outf(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
36
+ return at::_ops::_transform_bias_rescale_qkv_out::call(qkv, qkv_bias, num_heads, out0, out1, out2);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor acos(const at::Tensor & self);
21
+ TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self);
22
+ TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out);
23
+ TORCH_API at::Tensor & acos_(at::Tensor & self);
24
+
25
+ } // namespace meta
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API alias {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::alias")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "alias(Tensor(a) self) -> Tensor(a)")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/any_meta_dispatch.h ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false);
21
+ TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false);
22
+ TORCH_API at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out);
23
+ TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
24
+ TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
25
+ TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out);
26
+ TORCH_API at::Tensor any(const at::Tensor & self);
27
+ TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self);
28
+ TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out);
29
+
30
+ } // namespace meta
31
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> batch_norm_update_stats(const at::Tensor & input, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor ccol_indices(const at::Tensor & self);
21
+
22
+ } // namespace compositeexplicitautograd
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/column_stack_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor column_stack(at::TensorList tensors);
20
+ TORCH_API at::Tensor & column_stack_out(at::TensorList tensors, at::Tensor & out);
21
+ } // namespace native
22
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/concatenate.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/concatenate_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor
26
+ inline at::Tensor concatenate(at::TensorList tensors, int64_t dim=0) {
27
+ return at::_ops::concatenate::call(tensors, dim);
28
+ }
29
+
30
+ // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & concatenate_out(at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
32
+ return at::_ops::concatenate_out::call(tensors, dim, out);
33
+ }
34
+ // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & concatenate_outf(at::TensorList tensors, int64_t dim, at::Tensor & out) {
36
+ return at::_ops::concatenate_out::call(tensors, dim, out);
37
+ }
38
+
39
+ // aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor
40
+ inline at::Tensor concatenate(at::TensorList tensors, at::Dimname dim) {
41
+ return at::_ops::concatenate_names::call(tensors, dim);
42
+ }
43
+
44
+ // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
45
+ inline at::Tensor & concatenate_out(at::Tensor & out, at::TensorList tensors, at::Dimname dim) {
46
+ return at::_ops::concatenate_names_out::call(tensors, dim, out);
47
+ }
48
+ // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
49
+ inline at::Tensor & concatenate_outf(at::TensorList tensors, at::Dimname dim, at::Tensor & out) {
50
+ return at::_ops::concatenate_names_out::call(tensors, dim, out);
51
+ }
52
+
53
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_add_relu_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API cudnn_convolution_add_relu {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::cudnn_convolution_add_relu")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
26
+ };
27
+
28
+ struct TORCH_API cudnn_convolution_add_relu_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::cudnn_convolution_add_relu")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> cudnn_grid_sampler_backward(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & diag_embed_out(at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1);
21
+ TORCH_API at::Tensor & diag_embed_outf(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at