Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp +87 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +354 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py +78 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py +786 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py +1059 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py +341 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py +611 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py +121 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py +91 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py +395 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py +105 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py +85 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h +133 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_copy_from_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h +113 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2r_cuda_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_flash_attention_forward.h +47 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h +35 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_erfc_cuda_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round.h +44 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_sin_native.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_tanh_cpu_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_functional_assert_scalar_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_prelu_kernel_backward_cuda_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/any_meta_dispatch.h +31 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/column_stack_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/concatenate.h +53 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_add_relu_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h +24 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (91.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc
ADDED
|
Binary file (5.26 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// NOTE: Like interface.cpp, this file will be copied into AOTInductor
|
| 2 |
+
// generated output. This file is intended to keep implementation
|
| 3 |
+
// details separate from the implementation of the AOTI public
|
| 4 |
+
// interface. Note also that #includes should go into interface.cpp
|
| 5 |
+
// for simplicity of maintenance.
|
| 6 |
+
|
| 7 |
+
namespace torch {
|
| 8 |
+
namespace aot_inductor {
|
| 9 |
+
template <typename T>
|
| 10 |
+
void convert_output_to_handle(
|
| 11 |
+
const ArrayRefTensor<T>& output,
|
| 12 |
+
AtenTensorHandle& handle) {
|
| 13 |
+
handle = output.expensiveCopyToTensor();
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
template <typename... Ts, std::size_t... Is>
|
| 17 |
+
void convert_outputs_to_handles_helper(
|
| 18 |
+
const std::tuple<ArrayRefTensor<Ts>...>& outputs,
|
| 19 |
+
AtenTensorHandle* output_handles,
|
| 20 |
+
std::index_sequence<Is...>) {
|
| 21 |
+
(convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
|
| 22 |
+
}
|
| 23 |
+
template <typename... Ts>
|
| 24 |
+
void convert_outputs_to_handles(
|
| 25 |
+
const std::tuple<ArrayRefTensor<Ts>...>& outputs,
|
| 26 |
+
AtenTensorHandle* output_handles) {
|
| 27 |
+
convert_outputs_to_handles_helper(
|
| 28 |
+
outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <typename T>
|
| 32 |
+
void convert_handle_to_arrayref_tensor(
|
| 33 |
+
AtenTensorHandle handle,
|
| 34 |
+
ArrayRefTensor<T>& input) {
|
| 35 |
+
void* data_ptr;
|
| 36 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
|
| 37 |
+
int64_t dim;
|
| 38 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
|
| 39 |
+
int64_t numel;
|
| 40 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
|
| 41 |
+
int64_t* sizes;
|
| 42 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
|
| 43 |
+
int64_t* strides;
|
| 44 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
|
| 45 |
+
int32_t dtype;
|
| 46 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
|
| 47 |
+
int32_t device_type;
|
| 48 |
+
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
|
| 49 |
+
int32_t device_index;
|
| 50 |
+
AOTI_TORCH_ERROR_CODE_CHECK(
|
| 51 |
+
aoti_torch_get_device_index(handle, &device_index));
|
| 52 |
+
|
| 53 |
+
input = ArrayRefTensor<T>(
|
| 54 |
+
MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
|
| 55 |
+
MiniArrayRef<const int64_t>(sizes, dim),
|
| 56 |
+
MiniArrayRef<const int64_t>(strides, dim),
|
| 57 |
+
device_type,
|
| 58 |
+
device_index);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <typename... Ts, std::size_t... Is>
|
| 62 |
+
void convert_handles_to_inputs_helper(
|
| 63 |
+
AtenTensorHandle* input_handles,
|
| 64 |
+
std::tuple<ArrayRefTensor<Ts>...>& inputs,
|
| 65 |
+
std::index_sequence<Is...>) {
|
| 66 |
+
(convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
|
| 67 |
+
...);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
template <typename... Ts>
|
| 71 |
+
void convert_handles_to_inputs(
|
| 72 |
+
AtenTensorHandle* input_handles,
|
| 73 |
+
std::tuple<ArrayRefTensor<Ts>...>& inputs) {
|
| 74 |
+
convert_handles_to_inputs_helper(
|
| 75 |
+
input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
template <typename T>
|
| 79 |
+
void assert_numel(const ArrayRefTensor<T>& tensor, int64_t numel) {
|
| 80 |
+
if (tensor.numel() != numel) {
|
| 81 |
+
std::stringstream err;
|
| 82 |
+
err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
|
| 83 |
+
throw std::runtime_error(err.str());
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
} // namespace aot_inductor
|
| 87 |
+
} // namespace torch
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
|
| 2 |
+
#include <torch/csrc/inductor/aoti_runtime/interface.h>
|
| 3 |
+
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
|
| 4 |
+
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
|
| 5 |
+
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
|
| 6 |
+
|
| 7 |
+
#include <iostream>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
#include <stdexcept>
|
| 10 |
+
#include <vector>
|
| 11 |
+
|
| 12 |
+
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
|
| 13 |
+
try { \
|
| 14 |
+
__VA_ARGS__ \
|
| 15 |
+
} catch (const std::exception& e) { \
|
| 16 |
+
std::cerr << "Error: " << e.what() << std::endl; \
|
| 17 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 18 |
+
} catch (...) { \
|
| 19 |
+
std::cerr << "Unknown exception occurred." << std::endl; \
|
| 20 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 21 |
+
} \
|
| 22 |
+
return AOTI_RUNTIME_SUCCESS;
|
| 23 |
+
|
| 24 |
+
#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
|
| 25 |
+
do { \
|
| 26 |
+
AOTI_RUNTIME_CHECK( \
|
| 27 |
+
actual_size == expected_size, \
|
| 28 |
+
"expected " + std::string(name) + " vector size to be " + \
|
| 29 |
+
std::to_string(expected_size) + ", but got " + \
|
| 30 |
+
std::to_string(actual_size)); \
|
| 31 |
+
} while (0)
|
| 32 |
+
|
| 33 |
+
// AOTInductor uses at::addmm_out, which doesn't supports
|
| 34 |
+
// arguments that requires gradient. For this reason, we
|
| 35 |
+
// enforce no_grad context for run APIs.
|
| 36 |
+
//
|
| 37 |
+
// A RAII, thread local (!) guard that enables or disables grad mode upon
|
| 38 |
+
// construction, and sets it back to the original value upon destruction.
|
| 39 |
+
struct AOTINoGradGuard {
|
| 40 |
+
AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
|
| 41 |
+
aoti_torch_grad_mode_set_enabled(false);
|
| 42 |
+
}
|
| 43 |
+
~AOTINoGradGuard() {
|
| 44 |
+
aoti_torch_grad_mode_set_enabled(prev_mode);
|
| 45 |
+
}
|
| 46 |
+
bool prev_mode;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
extern "C" {
|
| 50 |
+
|
| 51 |
+
AOTIRuntimeError AOTInductorModelContainerCreate(
|
| 52 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 53 |
+
size_t num_models,
|
| 54 |
+
bool is_cpu,
|
| 55 |
+
const char* cubin_dir) {
|
| 56 |
+
return AOTInductorModelContainerCreateWithDevice(
|
| 57 |
+
container_handle,
|
| 58 |
+
num_models,
|
| 59 |
+
is_cpu ? "cpu" : "cuda",
|
| 60 |
+
cubin_dir);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
|
| 64 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 65 |
+
size_t num_models,
|
| 66 |
+
const char* device_str,
|
| 67 |
+
const char* cubin_dir) {
|
| 68 |
+
if (num_models == 0) {
|
| 69 |
+
std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
|
| 70 |
+
return AOTI_RUNTIME_FAILURE;
|
| 71 |
+
}
|
| 72 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 73 |
+
std::optional<std::string> cubin_dir_opt;
|
| 74 |
+
if (cubin_dir != nullptr) {
|
| 75 |
+
cubin_dir_opt.emplace(cubin_dir);
|
| 76 |
+
}
|
| 77 |
+
auto* container = new torch::aot_inductor::AOTInductorModelContainer(
|
| 78 |
+
num_models, std::string(device_str), cubin_dir_opt);
|
| 79 |
+
*container_handle =
|
| 80 |
+
reinterpret_cast<AOTInductorModelContainerHandle>(container);
|
| 81 |
+
})
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
AOTIRuntimeError AOTInductorModelContainerDelete(
|
| 85 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 86 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 87 |
+
auto* container =
|
| 88 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 89 |
+
container_handle);
|
| 90 |
+
delete container;
|
| 91 |
+
});
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
AOTIRuntimeError AOTInductorModelContainerRun(
|
| 95 |
+
AOTInductorModelContainerHandle container_handle,
|
| 96 |
+
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
|
| 97 |
+
// are stolen; the array itself is borrowed
|
| 98 |
+
size_t num_inputs,
|
| 99 |
+
AtenTensorHandle*
|
| 100 |
+
output_handles, // array for writing output AtenTensorHandle; handles
|
| 101 |
+
// will be stolen by the caller; the array itself is
|
| 102 |
+
// borrowed
|
| 103 |
+
size_t num_outputs,
|
| 104 |
+
AOTInductorStreamHandle stream_handle,
|
| 105 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 106 |
+
auto* container =
|
| 107 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 108 |
+
container_handle);
|
| 109 |
+
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
|
| 110 |
+
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
|
| 111 |
+
|
| 112 |
+
auto stream =
|
| 113 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 114 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 115 |
+
AOTINoGradGuard guard;
|
| 116 |
+
container->run(
|
| 117 |
+
input_handles, output_handles, stream, proxy_executor_handle);
|
| 118 |
+
})
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
|
| 122 |
+
AOTInductorModelContainerHandle container_handle,
|
| 123 |
+
size_t* num_constants) {
|
| 124 |
+
auto* container =
|
| 125 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 126 |
+
container_handle);
|
| 127 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 128 |
+
{ *num_constants = container->num_constants(); })
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantName(
|
| 132 |
+
AOTInductorModelContainerHandle container_handle,
|
| 133 |
+
size_t idx,
|
| 134 |
+
const char** name) {
|
| 135 |
+
auto* container =
|
| 136 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 137 |
+
container_handle);
|
| 138 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 139 |
+
{ *name = container->constant_name(idx); })
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
|
| 143 |
+
AOTInductorModelContainerHandle container_handle,
|
| 144 |
+
size_t idx,
|
| 145 |
+
const char** original_fqn) {
|
| 146 |
+
auto* container =
|
| 147 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 148 |
+
container_handle);
|
| 149 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 150 |
+
{ *original_fqn = container->constant_original_fqn(idx); })
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
|
| 154 |
+
AOTInductorModelContainerHandle container_handle,
|
| 155 |
+
size_t idx,
|
| 156 |
+
bool* from_folded) {
|
| 157 |
+
auto* container =
|
| 158 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
|
| 159 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
|
| 163 |
+
AOTInductorModelContainerHandle container_handle,
|
| 164 |
+
size_t idx,
|
| 165 |
+
int32_t* dtype) {
|
| 166 |
+
auto* container =
|
| 167 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 168 |
+
container_handle);
|
| 169 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 170 |
+
{ *dtype = container->constant_dtype(idx); })
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
|
| 174 |
+
AOTInductorModelContainerHandle container_handle,
|
| 175 |
+
AOTInductorConstantMapHandle constant_map_handle,
|
| 176 |
+
bool use_inactive,
|
| 177 |
+
bool validate_full_update) {
|
| 178 |
+
auto* container =
|
| 179 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 180 |
+
container_handle);
|
| 181 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 182 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 183 |
+
container->update_constant_buffer(
|
| 184 |
+
*input_map, use_inactive, validate_full_update);
|
| 185 |
+
})
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
|
| 189 |
+
AOTInductorModelContainerHandle container_handle,
|
| 190 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 191 |
+
return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
|
| 192 |
+
constant_map_handle,
|
| 193 |
+
/*use_inactive*/ true,
|
| 194 |
+
/*validate_full_update*/ true);
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
|
| 198 |
+
AOTInductorModelContainerHandle container_handle,
|
| 199 |
+
bool use_inactive,
|
| 200 |
+
AOTInductorStreamHandle stream_handle,
|
| 201 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 202 |
+
auto* container =
|
| 203 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 204 |
+
container_handle);
|
| 205 |
+
auto stream =
|
| 206 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 207 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 208 |
+
AOTINoGradGuard guard;
|
| 209 |
+
container->run_const_fold(use_inactive, stream, proxy_executor_handle);
|
| 210 |
+
})
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
|
| 214 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 215 |
+
auto* container =
|
| 216 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 217 |
+
container_handle);
|
| 218 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 219 |
+
container->swap_constant_buffer();
|
| 220 |
+
})
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
|
| 224 |
+
AOTInductorModelContainerHandle container_handle,
|
| 225 |
+
size_t* ret_num_inputs) {
|
| 226 |
+
auto* container =
|
| 227 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 228 |
+
container_handle);
|
| 229 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 230 |
+
{ *ret_num_inputs = container->num_inputs(); })
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
AOTIRuntimeError AOTInductorModelContainerGetInputName(
|
| 234 |
+
AOTInductorModelContainerHandle container_handle,
|
| 235 |
+
size_t input_idx,
|
| 236 |
+
const char** ret_input_names) {
|
| 237 |
+
auto* container =
|
| 238 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 239 |
+
container_handle);
|
| 240 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 241 |
+
{ *ret_input_names = container->input_name(input_idx); })
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
|
| 245 |
+
AOTInductorModelContainerHandle container_handle,
|
| 246 |
+
size_t* ret_num_outputs) {
|
| 247 |
+
auto* container =
|
| 248 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 249 |
+
container_handle);
|
| 250 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 251 |
+
{ *ret_num_outputs = container->num_outputs(); })
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
|
| 255 |
+
AOTInductorModelContainerHandle container_handle,
|
| 256 |
+
size_t output_idx,
|
| 257 |
+
const char** ret_output_names) {
|
| 258 |
+
auto* container =
|
| 259 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 260 |
+
container_handle);
|
| 261 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 262 |
+
{ *ret_output_names = container->output_name(output_idx); })
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
|
| 266 |
+
AOTInductorModelContainerHandle container_handle,
|
| 267 |
+
const char** in_spec,
|
| 268 |
+
const char** out_spec) {
|
| 269 |
+
auto* container =
|
| 270 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 271 |
+
container_handle);
|
| 272 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 273 |
+
*in_spec = container->get_in_spec();
|
| 274 |
+
*out_spec = container->get_out_spec();
|
| 275 |
+
})
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
AOTIRuntimeError AOTInductorModelCreate(
|
| 279 |
+
AOTInductorModelHandle* model_handle,
|
| 280 |
+
AOTInductorConstantMapHandle constant_map_handle){
|
| 281 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 282 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 283 |
+
auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
|
| 284 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 285 |
+
|
| 286 |
+
auto model = new torch::aot_inductor::AOTInductorModel(
|
| 287 |
+
constant_map,
|
| 288 |
+
constant_array,
|
| 289 |
+
"cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
|
| 290 |
+
""
|
| 291 |
+
);
|
| 292 |
+
|
| 293 |
+
if (input_map) {
|
| 294 |
+
for (auto const& kv : *input_map) {
|
| 295 |
+
constant_map->emplace(kv.first, kv.second);
|
| 296 |
+
}
|
| 297 |
+
} else {
|
| 298 |
+
model->load_constants();
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
*model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
|
| 302 |
+
})}
|
| 303 |
+
|
| 304 |
+
AOTIRuntimeError AOTInductorModelRun(
|
| 305 |
+
AOTInductorModelHandle model_handle,
|
| 306 |
+
AtenTensorHandle* input_handles,
|
| 307 |
+
AtenTensorHandle* output_handles) {
|
| 308 |
+
auto model =
|
| 309 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 310 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 311 |
+
AOTINoGradGuard guard;
|
| 312 |
+
model->run_impl(
|
| 313 |
+
input_handles,
|
| 314 |
+
output_handles,
|
| 315 |
+
(torch::aot_inductor::DeviceStreamType) nullptr,
|
| 316 |
+
nullptr);
|
| 317 |
+
})
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
|
| 321 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 322 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
|
| 323 |
+
model_handle);
|
| 324 |
+
delete model;
|
| 325 |
+
})}
|
| 326 |
+
|
| 327 |
+
AOTIRuntimeError AOTInductorModelGetNumOutputs(
|
| 328 |
+
AOTInductorModelHandle model_handle,
|
| 329 |
+
size_t* ret_num_outputs) {
|
| 330 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 331 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 332 |
+
*ret_num_outputs = model->num_outputs();
|
| 333 |
+
})
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
|
| 337 |
+
AOTInductorModelHandle model_handle,
|
| 338 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 339 |
+
auto model =
|
| 340 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 341 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 342 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 343 |
+
auto input_map =
|
| 344 |
+
reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
|
| 345 |
+
constant_map_handle);
|
| 346 |
+
|
| 347 |
+
for (auto const& kv : *input_map) {
|
| 348 |
+
constant_map->emplace(kv.first, kv.second);
|
| 349 |
+
}
|
| 350 |
+
model->update_constants_map(std::move(constant_map));
|
| 351 |
+
})
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
} // extern "C"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.fx.experimental.proxy_tensor import py_sym_types, SymBool, SymFloat, SymInt
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class _SymExprHash:
|
| 10 |
+
"""
|
| 11 |
+
Hash for a py_sym_types that will use the underlying sympy expression
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
sym_obj: Union[SymInt, SymFloat, SymBool]
|
| 15 |
+
|
| 16 |
+
def __hash__(self) -> int:
|
| 17 |
+
return hash((type(self.sym_obj), self.sym_obj.node.expr))
|
| 18 |
+
|
| 19 |
+
def __eq__(self, value) -> bool:
|
| 20 |
+
if not isinstance(value, _SymExprHash):
|
| 21 |
+
return False
|
| 22 |
+
return self.sym_obj.node.expr == value.sym_obj.node.expr
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _SymHashingDict:
|
| 26 |
+
"""
|
| 27 |
+
Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
|
| 28 |
+
existing sym proxies.
|
| 29 |
+
|
| 30 |
+
SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
|
| 31 |
+
fallback to symnodes.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self):
|
| 35 |
+
self.sym_hash_dict = {}
|
| 36 |
+
|
| 37 |
+
def __setitem__(self, key, value):
|
| 38 |
+
self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, key):
|
| 41 |
+
return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
|
| 42 |
+
|
| 43 |
+
def __contains__(self, key):
|
| 44 |
+
return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
|
| 45 |
+
|
| 46 |
+
def get(self, key, default=None):
|
| 47 |
+
return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
|
| 48 |
+
|
| 49 |
+
def _wrap_to_sym_expr_hash(self, key):
|
| 50 |
+
return _SymExprHash(key) if isinstance(key, py_sym_types) else key
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def dedupe_symints(graph: torch.fx.Graph):
    """
    Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs.

    We only dedupe from graph inputs to avoid adding a potential dependency in the forward
    from the backward.

    """

    # Maps a sym value (keyed by its sympy expression via _SymHashingDict) to
    # the canonical node producing it.
    sym_dict = _SymHashingDict()
    # Nodes whose sym value is derivable purely from placeholder inputs.
    resolvable_from_input_symints = set()

    for node in graph.nodes:
        val = node.meta.get("val", None)
        # Only nodes carrying a symbolic scalar value participate.
        if val is None or not isinstance(val, py_sym_types):
            continue

        if node.op == "placeholder":
            # Graph inputs seed the set of canonical sym producers.
            resolvable_from_input_symints.add(node)
            sym_dict[val] = node
        elif existing_node := sym_dict.get(val):
            # Duplicate of an input-resolvable expression: redirect all uses
            # to the canonical node and drop this one from the graph.
            node.replace_all_uses_with(existing_node)
            graph.erase_node(node)
        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
            # First producer of this expression computed solely from
            # input-resolvable nodes becomes the canonical one.
            sym_dict[val] = node
            resolvable_from_input_symints.add(node)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py
ADDED
|
@@ -0,0 +1,786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import inspect
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from ..._dynamo.utils import counters
|
| 8 |
+
from ..pattern_matcher import (
|
| 9 |
+
filter_nodes,
|
| 10 |
+
fwd_only,
|
| 11 |
+
joint_fwd_bwd,
|
| 12 |
+
register_replacement,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
log = logging.getLogger(__name__)
|
| 16 |
+
aten = torch.ops.aten
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _sfdp_pattern_1(query, key, value, inv_scale):
|
| 20 |
+
return (
|
| 21 |
+
torch.matmul(query, key.transpose(-2, -1))
|
| 22 |
+
.div(inv_scale)
|
| 23 |
+
.softmax(dim=-1)
|
| 24 |
+
.matmul(value)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _sfdp_replacement_1(query, key, value, inv_scale):
    """Fused replacement for _sfdp_pattern_1: one SDPA call."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        # The pattern divides by inv_scale; SDPA multiplies by scale.
        scale=1.0 / inv_scale,
    )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _sfdp_pattern_2(query, key, value, scale_factor):
|
| 42 |
+
return (
|
| 43 |
+
torch.matmul(query, key.transpose(-2, -1))
|
| 44 |
+
.mul(scale_factor)
|
| 45 |
+
.softmax(dim=-1)
|
| 46 |
+
.matmul(value)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _sfdp_replacement_2(query, key, value, scale_factor):
    """Fused replacement for _sfdp_pattern_2 (multiplicative scale)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=scale_factor,  # pattern multiplies, so scale maps directly
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    """Search pattern: dropout(softmax((Q @ K^T) / inv_scale_factor)) @ V."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale_factor)
        .softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
    """Fused replacement for _sfdp_pattern_3 (div scale + dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,  # pattern divides; SDPA multiplies
    )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
    """Search pattern: dropout(softmax((Q @ K^T) * scale_factor)) @ V."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
    """Fused replacement for _sfdp_pattern_4 (mul scale + dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale_factor,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _sfdp_pattern_5(query, key, value, attn_mask):
    """Search pattern: softmax(Q @ K^T / sqrt(E) + attn_mask) @ V, no dropout."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    # attn_weight = torch.dropout(attn_weight, dropout_p)
    return attn_weight @ value
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _sfdp_replacement_5(query, key, value, attn_mask):
    """Fused replacement for _sfdp_pattern_5 (additive mask, default scale)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        # Cast so the additive mask matches query dtype inside SDPA.
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
        # No explicit scale: SDPA's default matches the pattern's 1/sqrt(E).
    )
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
    """Search pattern: _sfdp_pattern_5 plus dropout (train=True)."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    """Fused replacement for _sfdp_pattern_6 (additive mask + dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _sfdp_pattern_7(query, key, value, dropout_p):
    """Half-precision attention: fp32 softmax + dropout, cast back to fp16."""
    # in real workloads inputs to matmul are permuted
    # causing matmul to expand to a series of expand and clone calls
    # we want the same to happen during pattern tracing
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Softmax is computed in fp32 for numerical stability.
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _sfdp_replacement_7(query, key, value, dropout_p):
    """Fused replacement for _sfdp_pattern_7."""
    # sdpa prefers inputs in permuted format
    # it makes a copy to put them in this format
    # if they aren't already
    # to make replacement efficient ensure that inputs to sdpa
    # are in required order
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _sfdp_pattern_8(query, key, value):
    """No-dropout variant of _sfdp_pattern_7."""
    # no dropout version of pattern 7
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _sfdp_replacement_8(query, key, value):
    """Fused replacement for _sfdp_pattern_8 (no dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _sfdp_pattern_9(query, key, value, dropout_p):
    """Like pattern 7, but Q is pre-scaled before the matmul."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    # Scale applied to q rather than to the score matrix.
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _sfdp_replacement_9(query, key, value, dropout_p):
    """Fused replacement for _sfdp_pattern_9."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _sfdp_pattern_10(query, key, value):
    """No-dropout variant of _sfdp_pattern_9."""
    # no dropout version of 9
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _sfdp_replacement_10(query, key, value):
    """Fused replacement for _sfdp_pattern_10 (no dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _sfdp_pattern_11(query, key, value, inv_scale):
    """Permuted-input div-scaled attention, no mask/dropout."""
    # Mainly for huggingface models
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _sfdp_replacement_11(query, key, value, inv_scale):
    """Fused replacement for _sfdp_pattern_11."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        # transpose(1, 2) reproduces the pattern's permute(0, 2, 1, 3).
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
    """_sfdp_pattern_11 plus dropout on the attention weights."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.nn.functional.dropout(
        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(v)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    """Fused replacement for _sfdp_pattern_12."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _sfdp_pattern_13(query, key, value, dropout_p):
|
| 308 |
+
attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
|
| 309 |
+
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
|
| 310 |
+
return torch.bmm(attn_weight, value)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _sfdp_replacement_13(query, key, value, dropout_p):
    """Fused replacement for _sfdp_pattern_13 (3-D inputs)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    # SDPA expects 4-D inputs, so add and later remove a leading batch dim.
    return aten.scaled_dot_product_attention(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        dropout_p=dropout_p,
        scale=1.0,  # the pattern applies no scaling
    ).squeeze(0)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
    """Additive-mask, div-scaled attention with permuted inputs."""
    # for BertLarge
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
        .softmax(dim=-1)
        .matmul(v)
    )
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
    """Fused replacement for _sfdp_pattern_14."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
    """Boolean masked_fill attention: positions where attn_mask == 0 get -inf."""
    # for DistilBert
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    # Broadcast the (bs, k_len) mask across heads and query positions.
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
    """Fused replacement for _sfdp_pattern_15 (boolean keep-mask)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in aten.scaled_dot_product_attention
    # (the pattern masks where mask == 0; SDPA keeps where mask is True,
    # hence the inverted == 1 comparison here).
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """_sfdp_pattern_14 plus dropout, with the result cast back to query dtype."""
    # for BertLarge with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        torch.nn.functional.dropout(
            (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
                dim=-1
            ),
            dropout_p,
        )
        .to(dtype=query.dtype)
        .matmul(v)
    )
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """Fused replacement for _sfdp_pattern_16."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """_sfdp_pattern_15 plus dropout on the attention weights."""
    # for DistilBert with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return (
        torch.nn.functional.dropout(
            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
        )
        @ v
    )
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """Fused replacement for _sfdp_pattern_17 (boolean keep-mask + dropout)."""
    counters["inductor"]["fuse_attention"] += 1  # record that the fusion fired
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in aten.scaled_dot_product_attention
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _sfdp_params_check(match):
    """
    Validate a candidate pattern-matcher match for SDPA fusion.

    Returns False when q/k/v dtypes or devices disagree, or when an additive
    attention mask exists but is not a tensor of a compatible dtype/device.
    """
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
    key = match.kwargs["key"].meta["val"]
    value = match.kwargs["value"].meta["val"]
    # SDPA requires uniform dtype and device across q/k/v.
    if not (query.dtype == key.dtype == value.dtype) or not (
        query.device == key.device == value.device
    ):
        return False
    add_mask_node = filter_nodes(match.nodes, aten.add.Tensor)
    # Has attn_mask add.
    if len(add_mask_node) > 0:
        attn_mask_node = add_mask_node[0].args[1]
        # attn_mask_node may be a float/int number.
        if not hasattr(attn_mask_node, "meta"):
            return False
        attn_mask = attn_mask_node.meta["val"]  # type: ignore[union-attr]
        # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
        # attn_mask.dtype == torch.float for models like albert.
        if (
            not isinstance(attn_mask, torch.Tensor)
            or not (
                attn_mask.dtype == query.dtype
                or attn_mask.dtype == torch.bool
                or attn_mask.dtype == torch.float
            )
            or query.device != attn_mask.device
        ):
            return False
    return True
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _sfdp_extra_check(scale_factor_op, disable_cuda=False):
    """
    Build an extra-check callback that also validates the scale factor.

    scale_factor_op: the aten op (e.g. aten.div.Tensor) whose second arg
    carries the scale in the matched graph.
    disable_cuda: reject matches whose query lives on a CUDA device.
    """

    def fn(match):
        scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
        # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
        scale_factor = scale_factor_node.args[1]
        # make sure the scale_factor a float/int. SymInt?
        if not isinstance(scale_factor, (float, int)):
            return False
        if (
            disable_cuda
            and "query" in match.kwargs
            and "cuda" in str(match.kwargs["query"].meta["val"].device)
        ):
            return False
        return _sfdp_params_check(match)

    return fn
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial(func, **kwargs), except the returned
    wrapper also exposes a signature with the bound parameters removed
    and carries func's __name__.
    """
    bound = functools.partial(func, **kwargs)

    # Drop every parameter that was bound by **kwargs.
    remaining_params = [
        param
        for name, param in inspect.signature(func).parameters.items()
        if name not in kwargs
    ]

    def wrapper(*args, **inner_kwargs):
        return bound(*args, **inner_kwargs)

    wrapper.__signature__ = inspect.Signature(parameters=remaining_params)  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def _get_sfdp_patterns():
    """
    Yield (unique_name, register_replacement kwargs) for every SDPA fusion
    candidate, in both float and half dtypes and in training (joint fwd+bwd
    trace) and inference (fwd-only trace) flavors.
    """
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g_inp = functools.partial(
        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
    )
    # attn_mask
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
        torch.empty, (1024, 128, 128), device=device, requires_grad=True
    )

    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
    g_bs1_inp = functools.partial(
        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
    )
    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)

    # softmax will generate a dtype conversion on inputs if they are in half,
    # but will not in float, so we generate a pattern for both
    for dtype in [torch.float, torch.half]:
        g = functools.partial(g_inp, dtype=dtype)
        b = functools.partial(b_inp, dtype=dtype)
        m = functools.partial(m_inp, dtype=dtype)
        m_float = functools.partial(m_inp, dtype=torch.float)
        c = functools.partial(c_inp, dtype=dtype)
        g_3d = functools.partial(g_3d_inp, dtype=dtype)
        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)

        # Each candidate: (search_fn, replace_fn, example_inputs,
        # scalar_workaround dict, extra_check callable).
        candidates = [
            (
                _sfdp_pattern_1,
                _sfdp_replacement_1,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_2,
                _sfdp_replacement_2,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_3,
                _sfdp_replacement_3,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_4,
                _sfdp_replacement_4,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_5,
                _sfdp_replacement_5,
                [g(), g(), g(), b()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_6,
                _sfdp_replacement_6,
                [g(), g(), g(), b()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_7,
                _sfdp_replacement_7,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_8,
                _sfdp_replacement_8,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_9,
                _sfdp_replacement_9,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_10,
                _sfdp_replacement_10,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_11,
                _sfdp_replacement_11,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_12,
                _sfdp_replacement_12,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_13,
                _sfdp_replacement_13,
                [g_3d(), g_3d(), g_3d()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_14,
                _sfdp_replacement_14,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_15,
                _sfdp_replacement_15,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_17,
                _sfdp_replacement_17,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
        ]
        # Patterns that additionally get a "_mask_fp32" variant name below.
        mask_fp32_patterns = ["pattern_16"]
        if dtype == torch.half:
            # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g(), g(), g(), m_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__

            # Name suffixes must stay in sync with the serialized-pattern index.
            if dtype != torch.float:
                name += "_half"
                if (
                    any(p in name for p in mask_fp32_patterns)
                    and args[3].dtype == torch.float32
                ):
                    name += "_mask_fp32"
            if args[0].size(0) == 1:
                name += "_bs1"

            training_name = name + "_training"
            yield training_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": joint_fwd_bwd,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }

            if workaround:
                assert len(workaround) == 1 and "dropout_p" in workaround
                # functools.partial insufficient because we look at signature downstream
                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                replacement = partialize_and_update_signature(
                    replacement, dropout_p=0.0
                )
                workaround = {}

            inference_name = name + "_inference"
            yield inference_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": fwd_only,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
@functools.lru_cache(None)
def _sfdp_init():
    """Register every SDPA replacement pattern exactly once.

    Pulls the pre-serialized search pattern for each candidate (generated via
    `gen_attention_patterns`) so no tracing is needed at runtime, then hands it
    to `register_replacement` together with the generated kwargs. The
    lru_cache(None) decorator makes repeated calls no-ops.
    """
    from .serialized_patterns.central_index import get_serialized_pattern

    for key, register_replacement_kwargs in _get_sfdp_patterns():
        register_replacement(
            search_fn_pattern=get_serialized_pattern(key),
            **register_replacement_kwargs,
        )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py
ADDED
|
@@ -0,0 +1,1059 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import operator
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from typing import (
|
| 6 |
+
Any,
|
| 7 |
+
DefaultDict,
|
| 8 |
+
Deque,
|
| 9 |
+
Dict,
|
| 10 |
+
Iterable,
|
| 11 |
+
Iterator,
|
| 12 |
+
List,
|
| 13 |
+
Optional,
|
| 14 |
+
Set,
|
| 15 |
+
Tuple,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch._dynamo.utils import counters
|
| 20 |
+
|
| 21 |
+
from .. import config
|
| 22 |
+
from ..pattern_matcher import (
|
| 23 |
+
CallFunctionVarArgs,
|
| 24 |
+
get_arg_value,
|
| 25 |
+
stable_topological_sort,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
try:
    # importing this will register fbgemm lowerings for inductor
    import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings  # noqa: F401

    has_fbgemm = True
except Exception:
    # fbgemm is only available in fb-internal environments; degrade gracefully
    # so the open-source build still imports this module.
    has_fbgemm = False
|
| 36 |
+
|
| 37 |
+
# Shorthand for the ATen operator namespace used throughout this pass.
aten = torch.ops.aten

log = logging.getLogger(__name__)

# Minimum number of matching nodes required before a group is fused.
MIN_FUSE_SET_SIZE = 5
# Upper bound on how many nodes are fused into a single group.
MAX_FUSE_SET_SIZE = 300
# Depth limit for the fusion-candidate search — presumably a BFS hop limit;
# confirm against the search implementation (not visible in this chunk).
MAX_FUSE_SEARCH_DEPTH = 5
# The maximum tensor size that can go into the fusion group
MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096

# exclude these nodes from BFS
# excluding get item improves optimizer compilation time by 60s
SEARCH_EXCLUSIONS = {operator.getitem}


# Default knobs consumed by GroupBatchFusionBase subclasses via
# self.graph_search_options.
default_graph_search_options = {
    "min_fuse_set_size": MIN_FUSE_SET_SIZE,
    "max_fuse_set_size": MAX_FUSE_SET_SIZE,
    "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
    "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
}

# Module-level options; currently an alias for the defaults.
graph_search_options = default_graph_search_options
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
    """
    Update the example value of the node in the graph to enable followup split cat opt.

    For ``op=torch.stack`` the metadata is a list of tensors to stack; for
    ``op=torch.unbind`` it is a single tensor to unbind along ``dim``. Any
    other op (or a node without a ``meta`` dict) is silently ignored.
    """
    if node is None or not hasattr(node, "meta"):
        return
    if op == torch.stack:
        node.meta["example_value"] = torch.stack(metadata, dim=dim)
    elif op == torch.unbind:
        node.meta["example_value"] = torch.unbind(metadata, dim=dim)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def update_pointwise_example_value(pointwise_node, input, other, op):
    """
    Update the example value of the add node in the graph to enable followup split cat opt.

    Supports ``torch.add`` and ``torch.mul``; other ops (or nodes without a
    ``meta`` dict) are silently ignored.
    """
    if pointwise_node is None or not hasattr(pointwise_node, "meta"):
        return
    if op == torch.add:
        pointwise_node.meta["example_value"] = torch.add(input, other)
    elif op == torch.mul:
        pointwise_node.meta["example_value"] = torch.mul(input, other)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class GroupBatchFusionBase:
    """Abstract base for group/batch fusion passes.

    Subclasses implement ``match`` (produce a hashable group key for a fusable
    node, or None) and ``fuse`` (rewrite a subset of same-key nodes).
    """

    def __init__(self, **kwargs):
        # Callers may override the search knobs; otherwise use module defaults.
        options = kwargs.pop("graph_search_options", default_graph_search_options)
        self.graph_search_options = options

    def match(self, node):
        raise NotImplementedError("match called on base")

    def fuse(self, graph, subset):
        raise NotImplementedError("fuse called on base")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Registries of fusion passes, populated by @register_fusion. The literal
# ``{}`` is the idiomatic (and marginally faster) spelling of ``dict()``.
PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_fusion(name: str, pre_grad=True):
    """Class decorator registering a fusion under ``name``.

    The class lands in PRE_GRAD_FUSIONS or POST_GRAD_FUSIONS depending on
    ``pre_grad`` and is returned unchanged.
    """

    def decorator(fusion_cls: GroupBatchFusionBase):
        registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
        registry[name] = fusion_cls
        return fusion_cls

    return decorator
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def list_group_batch_fusions(pre_grad=True) -> List[str]:
    """Return the names registered for the requested phase (pre or post grad)."""
    registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
    return list(registry.keys())
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
    """Emit aten nodes equivalent to ``torch.stack(input_tensors, dim=0)``.

    Each input is unsqueezed along dim 0, then the results are concatenated
    along dim 0; returns the fx node for the concatenation.
    """
    unsqueezed = [
        graph.call_function(aten.unsqueeze, args=(tensor,), kwargs={"dim": 0})
        for tensor in input_tensors
    ]
    return graph.call_function(aten.cat, args=(unsqueezed,), kwargs={"dim": 0})
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class GroupFusion(GroupBatchFusionBase):
    """
    Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
    """

    # Marker base class only: concrete group fusions subclass this; no shared
    # behavior beyond GroupBatchFusionBase yet.
    pass
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class BatchFusion(GroupBatchFusionBase):
    """
    Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm.
    """

    # Marker base class only: concrete batch fusions subclass this; no shared
    # behavior beyond GroupBatchFusionBase yet.
    pass
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class BatchPointwiseOpsFusionFactory(BatchFusion):
    # Parameterizes a batch fusion by the pointwise op it targets
    # (e.g. torch.add / torch.mul); subclasses read self.op in match/fuse.
    def __init__(self, op, **kwargs):
        super().__init__(**kwargs)
        self.op = op
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@register_fusion("batch_linear_post_grad", pre_grad=False)
|
| 161 |
+
class PostGradBatchLinearFusion(BatchFusion):
|
| 162 |
+
"""
|
| 163 |
+
Fuse ops in a batch way in post grad (aten level).
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
|
| 167 |
+
return (
|
| 168 |
+
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def _is_input_2d(self, input: torch.fx.Node) -> bool:
|
| 172 |
+
input_shapes = input.meta["tensor_meta"].shape
|
| 173 |
+
return (
|
| 174 |
+
len(input_shapes) == 2
|
| 175 |
+
and isinstance(input_shapes[0], int)
|
| 176 |
+
and isinstance(input_shapes[1], int)
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]:
|
| 180 |
+
if CallFunctionVarArgs(aten.mm).match(node):
|
| 181 |
+
input_m, weight_m = node.args
|
| 182 |
+
bias_m = None
|
| 183 |
+
|
| 184 |
+
elif CallFunctionVarArgs(aten.addmm.default).match(
|
| 185 |
+
node
|
| 186 |
+
) and self._addmm_node_can_be_fused(node):
|
| 187 |
+
bias_m, input_m, weight_m = node.args
|
| 188 |
+
else:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
# only handle the cases where inputs are 2D tensors
|
| 192 |
+
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type]
|
| 193 |
+
return None
|
| 194 |
+
m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr]
|
| 195 |
+
n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr]
|
| 196 |
+
batch_key = ("batch_linear", m, k, n, bias_m is not None)
|
| 197 |
+
return batch_key
|
| 198 |
+
|
| 199 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 200 |
+
batch_inputs = []
|
| 201 |
+
batch_weights = []
|
| 202 |
+
batch_biases = []
|
| 203 |
+
batch_nodes = []
|
| 204 |
+
|
| 205 |
+
for node in subset:
|
| 206 |
+
if CallFunctionVarArgs(aten.addmm.default).match(node):
|
| 207 |
+
bias, input, weight = node.args
|
| 208 |
+
elif CallFunctionVarArgs(aten.mm.default).match(node):
|
| 209 |
+
input, weight = node.args
|
| 210 |
+
bias = None
|
| 211 |
+
batch_nodes.append(node)
|
| 212 |
+
batch_inputs.append(input) # type: ignore[possibly-undefined]
|
| 213 |
+
batch_weights.append(weight) # type: ignore[possibly-undefined]
|
| 214 |
+
batch_biases.append(bias) # type: ignore[possibly-undefined]
|
| 215 |
+
|
| 216 |
+
with graph.inserting_before(subset[-1]):
|
| 217 |
+
fused_inputs = decompose_stack(graph, batch_inputs)
|
| 218 |
+
fused_weights = decompose_stack(graph, batch_weights)
|
| 219 |
+
fused_bmm = graph.call_function(
|
| 220 |
+
aten.bmm,
|
| 221 |
+
args=(fused_inputs, fused_weights),
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
for i, original_mm in enumerate(batch_nodes):
|
| 225 |
+
has_bias = False
|
| 226 |
+
with graph.inserting_after(fused_bmm):
|
| 227 |
+
new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))
|
| 228 |
+
if batch_biases[i]:
|
| 229 |
+
has_bias = True
|
| 230 |
+
new_bias_add = graph.call_function(
|
| 231 |
+
aten.add, args=((batch_biases[i], new_mm))
|
| 232 |
+
)
|
| 233 |
+
new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined]
|
| 234 |
+
original_mm.replace_all_uses_with(new_mm_cont)
|
| 235 |
+
new_mm_cont.meta.update(original_mm.meta)
|
| 236 |
+
graph.erase_node(original_mm)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@register_fusion("group_linear", pre_grad=False)
|
| 240 |
+
class GroupLinearFusion(GroupFusion):
|
| 241 |
+
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
|
| 242 |
+
input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
|
| 243 |
+
weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr]
|
| 244 |
+
return (
|
| 245 |
+
node.kwargs.get("beta", 1.0) == 1.0
|
| 246 |
+
and node.kwargs.get("alpha", 1.0) == 1.0
|
| 247 |
+
and len(input_shape) == 2
|
| 248 |
+
and len(weight_shape) == 2
|
| 249 |
+
and all(x % 2 == 0 for x in input_shape + weight_shape)
|
| 250 |
+
and all(
|
| 251 |
+
shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
|
| 252 |
+
for shape in input_shape + weight_shape
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def _mm_node_can_be_fused(self, node: torch.fx.Node):
|
| 257 |
+
input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr]
|
| 258 |
+
weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr]
|
| 259 |
+
return (
|
| 260 |
+
len(input_shape) == 2
|
| 261 |
+
and len(weight_shape) == 2
|
| 262 |
+
and all(x % 2 == 0 for x in input_shape + weight_shape)
|
| 263 |
+
and all(
|
| 264 |
+
shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
|
| 265 |
+
for shape in input_shape + weight_shape
|
| 266 |
+
)
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
|
| 270 |
+
if CallFunctionVarArgs(aten.mm.default).match(
|
| 271 |
+
node
|
| 272 |
+
) and self._mm_node_can_be_fused(node):
|
| 273 |
+
group_key = ("group_linear", True)
|
| 274 |
+
elif CallFunctionVarArgs(aten.addmm.default).match(
|
| 275 |
+
node
|
| 276 |
+
) and self._addmm_node_can_be_fused(node):
|
| 277 |
+
bias = node.args[0]
|
| 278 |
+
group_key = ("group_linear", bias is None)
|
| 279 |
+
else:
|
| 280 |
+
group_key = None
|
| 281 |
+
return group_key
|
| 282 |
+
|
| 283 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 284 |
+
group_inputs = []
|
| 285 |
+
group_weights = []
|
| 286 |
+
group_biases = []
|
| 287 |
+
group_nodes = []
|
| 288 |
+
for node in subset:
|
| 289 |
+
if CallFunctionVarArgs(aten.addmm.default).match(node):
|
| 290 |
+
bias, input, weight = node.args
|
| 291 |
+
else:
|
| 292 |
+
assert CallFunctionVarArgs(aten.mm.default).match(node)
|
| 293 |
+
input, weight = node.args
|
| 294 |
+
bias = None
|
| 295 |
+
|
| 296 |
+
group_nodes.append(node)
|
| 297 |
+
group_inputs.append(input)
|
| 298 |
+
group_weights.append(weight)
|
| 299 |
+
group_biases.append(bias)
|
| 300 |
+
|
| 301 |
+
if all(bias is None for bias in group_biases):
|
| 302 |
+
group_biases = None # type: ignore[assignment]
|
| 303 |
+
group_biases: Optional[List[Any]]
|
| 304 |
+
|
| 305 |
+
with graph.inserting_before(subset[0]):
|
| 306 |
+
fused_mm = graph.call_function(
|
| 307 |
+
torch.ops.fbgemm.gmm.default,
|
| 308 |
+
args=(group_inputs, group_weights, group_biases),
|
| 309 |
+
kwargs={"smart_fused": True},
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
for i, original_mm in enumerate(group_nodes):
|
| 313 |
+
with graph.inserting_after(fused_mm):
|
| 314 |
+
new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))
|
| 315 |
+
original_mm.replace_all_uses_with(new_mm)
|
| 316 |
+
new_mm.meta.update(original_mm.meta)
|
| 317 |
+
graph.erase_node(original_mm)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise operator (e.g., add, mul) in post grad pass.

    Same-shape/same-dtype pointwise nodes are stacked along a new leading dim,
    the op is applied once, and results are split back out with aten.select.
    """

    def __init__(self, op, **kwargs):
        super().__init__(op, **kwargs)
        # Stored again here even though the factory base already sets it.
        self.op = op

    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
        # note: we only consider the case where the inputs are tensors
        # for mixed precision training, we need to make sure the inputs
        # of the aten.cat when do the stack should be the same dtype
        # otherwise, the output of the aten.cat may be not the same as
        # its inputs, and cause dtype not same error in mm or addmm
        input, other = node.args
        return (
            input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape  # type: ignore[union-attr]
            if hasattr(input, "meta")
            and hasattr(other, "meta")
            and "tensor_meta" in input.meta  # type: ignore[union-attr]
            and "tensor_meta" in other.meta  # type: ignore[union-attr]
            else False
        )

    def match(self, node: torch.fx.Node):
        """Key fusable nodes by op name, shape, dtypes, alpha and rounding_mode."""
        if CallFunctionVarArgs(self.op).match(
            node
        ) and self._pointwise_node_can_be_fused(node):
            alpha = node.kwargs.get("alpha", 1.0)
            rounding_mode = node.kwargs.get("rounding_mode", None)
            input, other = node.args
            shape = list(input.meta["tensor_meta"].shape)  # type: ignore[union-attr]
            group_key = (
                "batch_" + self.op.__name__.lower() + "_post_grad",
                str(shape),
                str(input.meta["tensor_meta"].dtype),  # type: ignore[union-attr]
                str(other.meta["tensor_meta"].dtype),  # type: ignore[union-attr]
                str(alpha),
                str(rounding_mode),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Replace `subset` with one stacked pointwise op + per-node select."""
        batch_inputs, batch_others = [], []
        # All nodes in `subset` share the same alpha (it is part of the group
        # key), so reading it from the first node is safe.
        alpha = subset[0].kwargs.get("alpha", 1.0)

        for node in subset:
            input, other = node.args
            batch_inputs.append(input)
            batch_others.append(other)

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            stack_others = decompose_stack(graph, batch_others)

            # alpha is only a legal kwarg for aten.add.Tensor.
            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs, stack_others),
                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
            )
        for i, original_add in enumerate(subset):
            with graph.inserting_after(batch_op):
                new_add = graph.call_function(
                    torch.ops.aten.select, args=((batch_op, 0, i))
                )
            original_add.replace_all_uses_with(new_add)
            new_add.meta.update(original_add.meta)
            graph.erase_node(original_add)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@register_fusion("batch_linear_lhs")
|
| 394 |
+
class BatchLinearLHSFusion(BatchFusion):
|
| 395 |
+
"""
|
| 396 |
+
Batch linear left-hand side fusion. This pass tries to fuse the following patterns:
|
| 397 |
+
|
| 398 |
+
torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn)
|
| 399 |
+
-> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1))
|
| 400 |
+
|
| 401 |
+
We have a separate pass to eliminate contiguous transpose in a generic way.
|
| 402 |
+
"""
|
| 403 |
+
|
| 404 |
+
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
|
| 405 |
+
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
| 406 |
+
node
|
| 407 |
+
) and is_linear_node_can_be_fused(node):
|
| 408 |
+
input = get_arg_value(node, 0, "input")
|
| 409 |
+
bias = get_arg_value(node, 2, "bias")
|
| 410 |
+
group_key = ("batch_linear_lhs", bias is None, input)
|
| 411 |
+
else:
|
| 412 |
+
group_key = None
|
| 413 |
+
return group_key
|
| 414 |
+
|
| 415 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 416 |
+
batch_nodes = []
|
| 417 |
+
batch_input = None
|
| 418 |
+
batch_weights = []
|
| 419 |
+
batch_biases = []
|
| 420 |
+
split_sections = []
|
| 421 |
+
for node in subset:
|
| 422 |
+
input = get_arg_value(node, 0, "input")
|
| 423 |
+
weight = get_arg_value(node, 1, "weight")
|
| 424 |
+
bias = get_arg_value(node, 2, "bias")
|
| 425 |
+
batch_nodes.append(node)
|
| 426 |
+
if batch_input is None:
|
| 427 |
+
batch_input = input
|
| 428 |
+
else:
|
| 429 |
+
assert batch_input is input
|
| 430 |
+
batch_weights.append(weight)
|
| 431 |
+
if bias:
|
| 432 |
+
batch_biases.append(bias)
|
| 433 |
+
split_sections.append(weight.meta["example_value"].shape[0])
|
| 434 |
+
|
| 435 |
+
with graph.inserting_before(subset[0]):
|
| 436 |
+
cat_weights = graph.call_function(
|
| 437 |
+
torch.cat, args=(batch_weights,), kwargs={"dim": 0}
|
| 438 |
+
)
|
| 439 |
+
transposed_weights = graph.call_function(
|
| 440 |
+
torch.transpose, args=(cat_weights, 0, 1)
|
| 441 |
+
)
|
| 442 |
+
if len(batch_biases) > 0:
|
| 443 |
+
cat_biases = graph.call_function(
|
| 444 |
+
torch.cat, args=(batch_biases,), kwargs={"dim": 0}
|
| 445 |
+
)
|
| 446 |
+
fused_lhs = graph.call_function(
|
| 447 |
+
torch.addmm,
|
| 448 |
+
args=(cat_biases, batch_input, transposed_weights),
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
fused_lhs = graph.call_function(
|
| 452 |
+
torch.mm,
|
| 453 |
+
args=(batch_input, transposed_weights),
|
| 454 |
+
)
|
| 455 |
+
fused_lhs_list = graph.call_function(
|
| 456 |
+
torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
for i, node in enumerate(batch_nodes):
|
| 460 |
+
with graph.inserting_after(fused_lhs_list):
|
| 461 |
+
new_node = graph.call_function(
|
| 462 |
+
operator.getitem, args=(fused_lhs_list, i)
|
| 463 |
+
)
|
| 464 |
+
node.replace_all_uses_with(new_node)
|
| 465 |
+
new_node.meta.update(node.meta)
|
| 466 |
+
graph.erase_node(node)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def is_node_meta_valid(node: Optional[torch.fx.Node]):
    """True if `node` is None or carries an "example_value" in its meta dict.

    None is treated as valid so optional arguments (e.g. a missing bias)
    do not block fusion.
    """
    return node is None or "example_value" in node.meta
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def is_linear_node_can_be_fused(node: torch.fx.Node):
    """True when a functional.linear node has example metadata and 2D input/weight."""
    input = get_arg_value(node, 0, "input")
    weight = get_arg_value(node, 1, "weight")
    if not (
        is_node_meta_valid(node)
        and is_node_meta_valid(input)
        and is_node_meta_valid(weight)
    ):
        return False
    return (
        len(input.meta["example_value"].shape) == 2
        and len(weight.meta["example_value"].shape) == 2
    )
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
@register_fusion("batch_linear")
|
| 490 |
+
class PreGradBatchLinearFusion(BatchFusion):
|
| 491 |
+
"""
|
| 492 |
+
Batch linear fusion in pre grad pass.
|
| 493 |
+
Fuse linear with same size with torch.baddmm
|
| 494 |
+
"""
|
| 495 |
+
|
| 496 |
+
def _getitem_args(self, getitem_node: torch.fx.Node):
|
| 497 |
+
if getitem_node.target != operator.__getitem__ or (
|
| 498 |
+
getitem_node.op != "call_function"
|
| 499 |
+
):
|
| 500 |
+
return None
|
| 501 |
+
return getitem_node.args[0]
|
| 502 |
+
|
| 503 |
+
def match(self, node: torch.fx.Node):
|
| 504 |
+
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
| 505 |
+
node
|
| 506 |
+
) and is_linear_node_can_be_fused(node):
|
| 507 |
+
input = get_arg_value(node, 0, "input")
|
| 508 |
+
weight = get_arg_value(node, 1, "weight")
|
| 509 |
+
bias = get_arg_value(node, 2, "bias")
|
| 510 |
+
group_key = (
|
| 511 |
+
"batch_linear_pre_grad",
|
| 512 |
+
self._getitem_args(input),
|
| 513 |
+
str(input.meta["example_value"].shape),
|
| 514 |
+
str(weight.meta["example_value"].shape),
|
| 515 |
+
bias is None,
|
| 516 |
+
)
|
| 517 |
+
else:
|
| 518 |
+
group_key = None
|
| 519 |
+
return group_key
|
| 520 |
+
|
| 521 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 522 |
+
batch_nodes = []
|
| 523 |
+
batch_inputs = []
|
| 524 |
+
batch_weights = []
|
| 525 |
+
batch_biases = []
|
| 526 |
+
batch_inputs_metadata = []
|
| 527 |
+
batch_weights_metadata = []
|
| 528 |
+
batch_biases_metadata = []
|
| 529 |
+
for node in subset:
|
| 530 |
+
batch_nodes.append(node)
|
| 531 |
+
input = get_arg_value(node, 0, "input")
|
| 532 |
+
batch_inputs.append(input)
|
| 533 |
+
batch_inputs_metadata.append(input.meta["example_value"])
|
| 534 |
+
weight = get_arg_value(node, 1, "weight")
|
| 535 |
+
batch_weights.append(weight)
|
| 536 |
+
batch_weights_metadata.append(weight.meta["example_value"])
|
| 537 |
+
bias = get_arg_value(node, 2, "bias")
|
| 538 |
+
batch_biases.append(bias)
|
| 539 |
+
if bias is not None and hasattr(bias, "meta"):
|
| 540 |
+
batch_biases_metadata.append(bias.meta["example_value"])
|
| 541 |
+
|
| 542 |
+
with graph.inserting_before(subset[0]):
|
| 543 |
+
stack_inputs = graph.call_function(
|
| 544 |
+
torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
|
| 545 |
+
)
|
| 546 |
+
update_stack_example_value(stack_inputs, batch_inputs_metadata)
|
| 547 |
+
stack_weights = graph.call_function(
|
| 548 |
+
torch.stack, args=(batch_weights,), kwargs={"dim": 0}
|
| 549 |
+
)
|
| 550 |
+
update_stack_example_value(stack_weights, batch_weights_metadata)
|
| 551 |
+
transpose_weight = graph.call_function(
|
| 552 |
+
torch.transpose, args=(stack_weights, 1, 2)
|
| 553 |
+
)
|
| 554 |
+
if all(bias is None for bias in batch_biases):
|
| 555 |
+
bmm = graph.call_function(
|
| 556 |
+
torch.bmm,
|
| 557 |
+
args=(stack_inputs, transpose_weight),
|
| 558 |
+
)
|
| 559 |
+
else:
|
| 560 |
+
stack_biases = graph.call_function(
|
| 561 |
+
torch.stack, args=(batch_biases,), kwargs={"dim": 0}
|
| 562 |
+
)
|
| 563 |
+
update_stack_example_value(stack_biases, batch_biases_metadata)
|
| 564 |
+
unsqueeze_biases = graph.call_function(
|
| 565 |
+
torch.unsqueeze, args=(stack_biases, 1)
|
| 566 |
+
)
|
| 567 |
+
bmm = graph.call_function(
|
| 568 |
+
torch.baddbmm,
|
| 569 |
+
args=(unsqueeze_biases, stack_inputs, transpose_weight),
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
|
| 573 |
+
for i, linear in enumerate(batch_nodes):
|
| 574 |
+
with graph.inserting_after(bmm):
|
| 575 |
+
getitem = graph.call_function(operator.getitem, args=(bmm, i))
|
| 576 |
+
linear.replace_all_uses_with(getitem)
|
| 577 |
+
getitem.meta.update(linear.meta)
|
| 578 |
+
graph.erase_node(linear)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
@register_fusion("batch_layernorm")
class BatchLayernormFusion(BatchFusion):
    """
    Batch layer norm fusion in pre grad pass.

    Independent ``torch.nn.functional.layer_norm`` calls with matching
    input/weight/bias shapes, ``normalized_shape`` and ``eps`` are stacked
    into one layer_norm over the stacked input, then unbound back into the
    original per-node outputs.
    """

    def match(self, node: torch.fx.Node):
        # Returns the grouping key for a fusable layer_norm node, or None.
        # Nodes sharing the same key are candidates to be batched together.
        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 2, "weight")
            bias = get_arg_value(node, 3, "bias")
            # The key is built only when example_value metadata is available
            # for input/weight/bias — shapes are needed to verify the group.
            group_key = (
                (
                    "batch_layernorm",
                    str(input.meta["example_value"].shape),
                    str(weight.meta["example_value"].shape)
                    if weight is not None
                    else "",
                    str(bias.meta["example_value"].shape) if bias is not None else "",
                    str(get_arg_value(node, 1, "normalized_shape")),
                    str(get_arg_value(node, 4, "eps")),
                )
                if "example_value" in input.meta
                and is_node_meta_valid(weight)
                and is_node_meta_valid(bias)
                else None
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Replace `subset` with one stacked layer_norm (+ mul/add for affine)."""
        group_inputs = []
        group_shapes = []
        group_weights = []
        group_biases = []
        group_epss = []
        group_nodes = []
        group_inputs_metadata = []
        group_biases_metadata = []
        group_weights_metadata = []
        # Collect per-node args and example_value metadata for the group.
        for node in subset:
            group_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            group_inputs.append(input)
            group_inputs_metadata.append(input.meta["example_value"])
            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
            weight = get_arg_value(node, 2, "weight")
            group_weights.append(weight)
            if weight is not None and hasattr(weight, "meta"):
                group_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 3, "bias")
            group_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                group_biases_metadata.append(bias.meta["example_value"])
            eps = get_arg_value(node, 4, "eps")
            if eps is None:
                # layer_norm's documented default eps.
                eps = 1e-5
            group_epss.append(eps)
        # Stack in front of the normalized dims so layer_norm still normalizes
        # over the trailing `normalized_shape` dims of the stacked tensor.
        stack_dim = -1 - len(group_shapes[-1])

        # Collapse all-None weight/bias lists to None so the affine step
        # below is skipped entirely for those cases.
        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        group_biases: Optional[List[Any]]
        if all(weight is None for weight in group_weights):
            group_weights = None  # type: ignore[assignment]
        group_weights: Optional[List[Any]]
        assert all(
            eps == group_epss[0] for eps in group_epss
        ), "all epsilon values must be equal"

        with graph.inserting_before(subset[0]):
            stack_input = graph.call_function(
                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
            )
            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
            if group_weights is not None:
                stack_weight = graph.call_function(
                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_weight, group_weights_metadata)
            else:
                stack_weight = None
            if group_biases is not None:
                stack_bias = graph.call_function(
                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_bias, group_biases_metadata)
            else:
                stack_bias = None

            # Run the normalization without weight/bias; the affine transform
            # is applied manually below with explicit mul/add so the stacked
            # weights/biases broadcast correctly.
            batch_layer_norm = graph.call_function(
                torch.nn.functional.layer_norm,
                args=(stack_input, group_shapes[-1]),
                kwargs={"eps": group_epss[-1]},
            )
            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]

            if group_weights is not None and group_biases is not None:
                # out = weight * norm + bias, with example_value metadata
                # refreshed after each pointwise op.
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )
            elif group_weights is not None and group_biases is None:
                # out = weight * norm
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
            elif group_weights is None and group_biases is not None:
                # out = norm + bias
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )

            # Split the batched result back along the stack dimension.
            batch_layer_norm_unbind = graph.call_function(
                torch.unbind,
                args=(batch_layer_norm,),
                kwargs={"dim": stack_dim},
            )
            update_stack_example_value(
                batch_layer_norm_unbind,
                batch_layer_norm.meta["example_value"],
                op=torch.unbind,
                dim=stack_dim,
            )

        # Rewire each original layer_norm's users to its unbound slice,
        # then drop the original node.
        for i, node in enumerate(group_nodes):
            with graph.inserting_after(batch_layer_norm_unbind):
                new_node = graph.call_function(
                    operator.getitem, args=(batch_layer_norm_unbind, i)
                )
                node.replace_all_uses_with(new_node)
                new_node.meta.update(node.meta)
                graph.erase_node(node)
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass.
    We fuse it in random place, and the introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs):
        super().__init__(op, **kwargs)
        # The pointwise callable (e.g. torch.tanh) this instance batches.
        self.op = op

    def match(self, node: torch.fx.Node):
        # Returns a grouping key for nodes calling self.op on same-shaped
        # inputs, or None if the node is not fusable.
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            # for relu op, we also use the inplace to construct the key
            group_key = (
                "batch_" + self.op.__name__.lower() + "_pre_grad",
                str(input.meta["example_value"].shape),
                str(node.kwargs.get("inplace", False)),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """Stack the inputs, apply self.op once, and unbind the result back."""
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            if self.op == torch.nn.functional.relu:
                # relu is the only op here that takes `inplace`; all grouped
                # nodes share the same inplace flag (it is part of the key).
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
                )
            else:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                )
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
        # Rewire each original node's users to the matching unbound slice.
        for i, node in enumerate(batch_nodes):
            with graph.inserting_after(unbind_op):
                getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
                node.replace_all_uses_with(getitem)
                getitem.meta.update(node.meta)
                graph.erase_node(node)
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
@register_fusion("batch_tanh")
class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Pre-grad batching of independent ``torch.tanh`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic pointwise batching machinery with tanh.
        super().__init__(torch.tanh, **kwargs)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
@register_fusion("batch_sigmoid")
class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Pre-grad batching of independent ``torch.sigmoid`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic pointwise batching machinery with sigmoid.
        super().__init__(torch.sigmoid, **kwargs)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
@register_fusion("batch_relu")
class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
    """Pre-grad batching of independent ``torch.nn.functional.relu`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic pointwise batching machinery with relu.
        super().__init__(torch.nn.functional.relu, **kwargs)
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Post-grad batching of independent ``aten.add.Tensor`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic post-grad pointwise batching with aten add.
        super().__init__(aten.add.Tensor, **kwargs)
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Post-grad batching of independent ``aten.sub.Tensor`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic post-grad pointwise batching with aten sub.
        super().__init__(aten.sub.Tensor, **kwargs)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Post-grad batching of independent ``aten.div.Tensor`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic post-grad pointwise batching with aten div.
        super().__init__(aten.div.Tensor, **kwargs)
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
    """Post-grad batching of independent ``aten.mul.Tensor`` calls."""

    def __init__(self, **kwargs):
        # Delegate to the generic post-grad pointwise batching with aten mul.
        super().__init__(aten.mul.Tensor, **kwargs)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
class _OrderedSet:
|
| 849 |
+
def __init__(self, param=None):
|
| 850 |
+
if param:
|
| 851 |
+
self.rep = OrderedDict({k: None for k in param})
|
| 852 |
+
else:
|
| 853 |
+
self.rep = OrderedDict()
|
| 854 |
+
|
| 855 |
+
def __contains__(self, o):
|
| 856 |
+
return o in self.rep
|
| 857 |
+
|
| 858 |
+
def __len__(self):
|
| 859 |
+
return self.rep.__len__()
|
| 860 |
+
|
| 861 |
+
def append(self, o):
|
| 862 |
+
self.rep[o] = None
|
| 863 |
+
|
| 864 |
+
def __iter__(self):
|
| 865 |
+
return self.rep.keys().__iter__()
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def find_independent_subset_greedy(
    node_list: Iterable[torch.fx.Node],
    graph_search_options: Dict[str, Any],
) -> Iterator[Iterable[torch.fx.Node]]:
    """
    Yields a list of subsets of `node_list` where no element in the subset
    depends on any other element in the subset. This results in a set of
    independent nodes which can be fused together.

    The order of `node_list` is preserved within each subset so we can benefit
    from split-cat elimination in later passes.

    During iteration it is only safe to mutate the graph by changing the nodes
    that have been returned.

    graph_search_options:
      - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
        this size will be ignored.
      - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
        be broken to be at most this size.
    """

    # Compute all the children of `node` which are members of
    # `interesting_nodes`.
    def find_dependent_nodes(node, interesting_nodes):
        visited_node_set: Set[torch.fx.Node] = {node}
        dep_set: Set[torch.fx.Node] = set()

        # Iterative walk over the node's transitive inputs.
        work = [node]
        while work:
            node = work.pop()
            for input_node in node.all_input_nodes:
                if input_node in interesting_nodes:
                    dep_set.add(input_node)

                if input_node not in visited_node_set:
                    visited_node_set.add(input_node)
                    work.append(input_node)

        return dep_set

    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
    max_fuse_set_size = graph_search_options["max_fuse_set_size"]

    # node_list needs to be a set because we only track the nodes that are left
    # in it (and we want to do the `in` on a set, not a list). But we want to
    # keep the correct order.
    node_list = _OrderedSet(node_list)

    # Memoizes a node's dependency set across rounds so the graph walk is not
    # repeated for nodes deferred to a later round.
    cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
    while node_list:
        subset: List[torch.fx.Node] = []
        subset_deps: Set[torch.fx.Node] = set()

        # Nodes that conflict with the current subset get retried next round.
        next_round_node_list = _OrderedSet()
        for node in node_list:
            # Defer if the subset is full or this node is a dependency of a
            # node already chosen for the subset.
            if len(subset) >= max_fuse_set_size or node in subset_deps:
                next_round_node_list.append(node)
                continue

            dep_set = cache.pop(node, None)
            if dep_set is None:
                dep_set = find_dependent_nodes(node, node_list)

            # Accept the node only if none of its dependencies are already in
            # the subset (independence within the subset).
            if not dep_set.intersection(subset):
                subset.append(node)
                subset_deps.update(dep_set)
            else:
                next_round_node_list.append(node)
                cache[node] = dep_set

        if len(subset) >= min_fuse_set_size:
            # Careful here - the caller uses the subsets to fuse nodes together
            # so we need to clear any cache entry that contains one of the
            # returned nodes because the dependency list could be different
            # (larger) after the merge.
            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
            yield subset

        node_list = next_round_node_list
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def get_fusion_candidates(
    rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
    """
    Search fusion candidates for a specific rule using BFS starting from the root node.
    We only search the subgraph within graph_search_options["max_fuse_search_depth"].

    Returns a mapping from the rule's group key to the ordered, de-duplicated
    list of matching nodes found during the search.
    """
    q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()

    candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
        list
    )

    # Some targets are excluded from the search entirely.
    if root_node.target in SEARCH_EXCLUSIONS:
        return candidate_dict

    visited_set: Set[torch.fx.Node] = set()

    # Seed the BFS queue with the root's direct inputs at depth 1.
    for next_node in root_node.all_input_nodes:
        q.append((1, next_node))
        visited_set.add(next_node)

    while len(q) > 0:
        depth, node = q.popleft()

        # Skip nodes already consumed by an earlier fusion of this rule.
        if node in fused_set:
            continue

        key = rule.match(node)
        if key is not None:
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
        else:
            # Only keep expanding through non-matching nodes, and only up to
            # the configured search depth.
            if depth < rule.graph_search_options["max_fuse_search_depth"]:
                for next_node in node.all_input_nodes:
                    if next_node not in visited_set:
                        visited_set.add(next_node)
                        q.append((depth + 1, next_node))

    return candidate_dict
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
    """
    Apply a single fusion rule across the whole graph: walk nodes in reverse
    (after a stable topological sort), collect candidate groups for each node,
    split them into independent subsets, and fuse each subset, bumping the
    matching counter per fused subset.
    """
    stable_topological_sort(graph)  # type: ignore[arg-type]
    # Nodes fused by this rule so far; prevents re-fusing the same node.
    fused_set: Set[torch.fx.Node] = set()

    for node in reversed(graph.nodes):
        candidates = get_fusion_candidates(rule, node, fused_set)

        for key, candidate_nodes in candidates.items():
            # Groups smaller than the configured minimum are not worth fusing.
            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
                continue

            for subset in find_independent_subset_greedy(
                candidate_nodes, rule.graph_search_options
            ):
                rule.fuse(graph, subset)
                fused_set.update(subset)
                if isinstance(rule, GroupFusion):
                    counters["inductor"]["group_fusion"] += 1
                elif isinstance(rule, BatchFusion):
                    counters["inductor"]["batch_fusion"] += 1
                else:
                    counters["inductor"]["unknown_group_batch_fusion"] += 1

                log.debug(
                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}"  # noqa: G004
                )
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
    """
    Instantiate fusion passes from a {name: options} config mapping.

    Each name is looked up in the pre- or post-grad registry; its options are
    overlaid on the module-level default graph_search_options.
    """
    registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
    fusions: List[GroupBatchFusionBase] = []
    for name, options in config_options.items():
        # Per-fusion options override the shared search defaults.
        merged_options = {**graph_search_options, **options}
        fusions.append(registry[name](graph_search_options=merged_options))  # type: ignore[operator]
    return fusions
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    """
    Build the configured group/batch fusion rules and run each over `graph`.

    Pre-grad: rules come straight from config.pre_grad_fusion_options.
    Post-grad: options are split by the "require_fbgemm" flag; fbgemm-backed
    fusions are only added when fbgemm is available.
    """
    fusions: List[GroupBatchFusionBase] = []
    # we keep all current pre grad fusions to keep
    # current implementation, will remove this later
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
    else:
        # Partition post-grad options by whether they require fbgemm,
        # preserving config order within each partition.
        fbgemm_fusions: Dict[str, Any] = {}
        non_fbgemm_fusions: Dict[str, Any] = {}
        for name, options in config.post_grad_fusion_options.items():
            if options.get("require_fbgemm", False):
                fbgemm_fusions[name] = options
            else:
                non_fbgemm_fusions[name] = options
        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
        if has_fbgemm:
            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for rule in fusions:
        apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import typing
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from typing import Dict, List, Set
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch._guards
|
| 8 |
+
from torch._inductor.constant_folding import ConstantFolder
|
| 9 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 10 |
+
|
| 11 |
+
from .. import config
|
| 12 |
+
from ..pattern_matcher import (
|
| 13 |
+
CallFunction,
|
| 14 |
+
init_once_fakemode,
|
| 15 |
+
KeywordArg,
|
| 16 |
+
Match,
|
| 17 |
+
PatternMatcherPass,
|
| 18 |
+
register_graph_pattern,
|
| 19 |
+
stable_topological_sort,
|
| 20 |
+
)
|
| 21 |
+
from .replace_random import replace_random_passes
|
| 22 |
+
|
| 23 |
+
log = logging.getLogger(__name__)
|
| 24 |
+
patterns = PatternMatcherPass()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@init_once_fakemode
def lazy_init():
    """Initialize the pattern tables used by the joint-graph passes.

    Runs at most once (via ``init_once_fakemode``); each ``_*_init`` call
    registers its patterns as a side effect.
    """
    # Imported locally so the registration modules are only loaded on demand.
    from .fuse_attention import _sfdp_init
    from .misc_patterns import _misc_patterns_init
    from .pad_mm import _pad_mm_init

    _pad_mm_init()
    _sfdp_init()
    _misc_patterns_init()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@torch.utils._python_dispatch._disable_current_modes()
def remove_no_ops(
    gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
):
    "Removes no-ops: (+ 0, - 0, * 1, / 1)"
    aten = torch.ops.aten
    graph = gm.graph

    def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
        # Meta values are considered equal only when both are tensors and all
        # requested fields match.
        if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
            return False
        for field in fields:
            if getattr(t1, field) != getattr(t2, field):
                return False
        return True

    def replace_no_op(node, replace_input_index):
        # Replace `node` with its arg at `replace_input_index`, inserting a
        # dtype cast when only shape/device (not dtype) match.
        replacement = node.args[replace_input_index]

        # https://github.com/pytorch/pytorch/issues/86128 causes
        # non-Tensor inputs even for ops with only Tensor inputs.
        # TODO - decompose/type promote to avoid this
        if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
            return

        if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
            if fake_tensors_eq(
                node.meta["val"],
                replacement.meta["val"],
                ("shape", "device"),
            ):
                # Same shape/device but different dtype: cast the replacement
                # to the node's dtype to preserve semantics.
                with graph.inserting_after(node):
                    replacement = graph.call_function(
                        torch.ops.prims.convert_element_type.default,
                        args=(replacement, node.meta["val"].dtype),
                    )
            else:
                # Shapes/devices differ (e.g. broadcasting) — not a no-op.
                return

        node.replace_all_uses_with(replacement)
        replacement.meta.update(node.meta)
        graph.erase_node(node)

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        # TODO handle Tensor-Scalar adds, it's a different schema
        if node.target == aten.add.Tensor and len(node.args) == 2:
            # x + 0 or 0 + x; alpha must be 1 for the identity to hold.
            if (
                not any(e in zeros for e in node.args)
                or node.kwargs.get("alpha", 1) != 1
            ):
                continue

            replace_index = 1 if node.args[0] in zeros else 0
            replace_no_op(node, replace_index)

        elif node.target == aten.sub.Tensor and len(node.args) == 2:
            # x - 0 only; 0 - x is not a no-op.
            if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
                continue

            replace_no_op(node, 0)

        elif node.target == aten.mul.Tensor and len(node.args) == 2:
            # x * 1 or 1 * x.
            if not any(e in ones for e in node.args):
                continue

            replace_input_index = 1 if node.args[0] in ones else 0
            replace_no_op(node, replace_input_index)

        elif (
            node.target == aten.div.Tensor
            and len(node.args) == 2
            and node.args[1] in ones
        ):
            # x / 1 only.
            replace_no_op(node, 0)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@torch.utils._python_dispatch._disable_current_modes()
def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant views by reusing existing ones.

    Only ``aten.view.dtype`` nodes are handled: repeated dtype-views of the
    same source are replaced by the first equivalent view, and views that end
    up with no users are erased.
    """

    # A dictionary mapping a tensor to all aliased views.
    views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        if node.target != torch.ops.aten.view.dtype:
            continue

        src = node.args[0]
        to_type = node.args[1]
        existing_views = views.get(src)
        is_needed = True

        if existing_views:
            # Replace the view with an existing view if available.
            alias = existing_views.get(to_type)
            if alias:
                is_needed = False
                node.replace_all_uses_with(alias)
                alias.meta.update(node.meta)
                graph.erase_node(node)
        else:
            # First view of this source: seed its alias group with the source
            # itself keyed by the source's own dtype.
            from_type = src.meta["val"].dtype
            existing_views = {from_type: src}
            views[src] = existing_views

        if is_needed:
            # Save the new alias but do not replace existing one.
            existing_views.setdefault(to_type, node)
            views[node] = existing_views

    # Clean up unused views.
    while True:
        # Iterate until a fixed point: erasing a view can strand another.
        unused_views = [alias for alias in views if not alias.users]
        if len(unused_views) == 0:
            break
        for unused in unused_views:
            views.pop(unused)
            graph.erase_node(unused)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
    with a tensor constructor call: aten.full([shape], value, ...)
    """

    def __init__(self, gm, skip_constructors=False):
        super().__init__(gm, skip_constructors)
        self.node_storages_ptrs: Dict[torch.fx.Node, int] = {}
        # Weak refs to each folded node's storage; used by the caller to
        # detect constants that share (alias) storage.
        self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {}
        # we may constant fold a tensor which in the graph has a sym size
        # see: [constant folding refining of symints]
        self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
        # TODO - we could also Tensors which get replaced with arange here
        # Fold only non-empty strided tensors whose elements are all equal to
        # the first element, i.e. replaceable by one `full` call.
        return (
            t.numel() != 0
            and bool((t == t.flatten()[0]).all())
            and torch._C._has_storage(t)
            and t.layout == torch.strided
        )

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        # Record the scalar value (not the whole tensor) plus the storage ref
        # and concrete shape so the constant can be rebuilt via aten.full.
        self.node_replacements[node] = tensor.flatten()[0].item()
        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
        shape = list(tensor.shape)
        assert all(type(dim) is int for dim in shape)
        self.node_replacements_shapes[node] = shape
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold_uniform_value(gm: torch.fx.GraphModule):
    "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops."
    aten = torch.ops.aten

    # Constant folding can leak memory, especially with repeated compilation, so we are only going to
    # remove constants which can be replaced with a constructor.
    cf = UniformValueConstantFolder(gm)
    cf.run()

    node_replacements = cf.node_replacements

    # note: [constant folding refining of symints]
    # constant folding will partially evaluate a graph such that values which have dependencies which
    # are entirely known at compile time may also become compile time constants. in some cases,
    # this will include symints which we had not yet previously deduced are guaranteed a
    # constant value and is then deduced in constant folding. an example is:
    # unbacked_symint_eq_11 = torch.full((), 11).item()
    # torch.full((unbacked_symint_eq_11,), 0)
    node_replacements_shapes = cf.node_replacements_shapes

    graph = gm.graph

    # Collect the inserted full(0)/full(1) nodes so remove_no_ops can later
    # simplify x+0, x-0, x*1, x/1 against them.
    zeros = set()
    ones = set()

    # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
    # so just constant-ify if a Tensor is unaliased
    constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()

    for node in cf.node_replacements:
        constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1

    for node, value in node_replacements.items():
        # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now
        # hasn't shown up to be important yet
        fake_tensor = node.meta["val"]
        if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
            continue

        # Skip constants whose storage is shared with another folded node
        # (replacing one alias would change aliasing behavior).
        if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
            continue

        with graph.inserting_after(node):
            # the conversion from tensor and back to value can be lossy, just use the original full ctor value
            if (
                node.op == "call_function"
                and node.target == aten.full.default
                and len(node.args) == 2
            ):
                value = node.args[1]

            # refines symints, see [constant folding refining of symints] above
            for runtime_size, compile_time_size in zip(
                node_replacements_shapes[node], fake_tensor.shape
            ):
                torch._check(runtime_size == compile_time_size)

            # zeros, and ones just get traced into full, so we insert those
            new_node = graph.call_function(
                aten.full.default,
                args=(node_replacements_shapes[node], value),
                kwargs={
                    "dtype": fake_tensor.dtype,
                    "layout": torch.strided,
                    "device": fake_tensor.device,
                    "pin_memory": False,
                },
            )

            new_node.meta.update(node.meta)
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if value == 0:
                zeros.add(new_node)
            elif value == 1:
                ones.add(new_node)

    remove_no_ops(gm, zeros, ones)
    remove_redundant_views(gm)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def joint_graph_passes(graph: torch.fx.GraphModule):
    """
    Run FX transformations on the joint forwards+backwards graph.

    Applies (in order) constant folding of uniform-value tensors, the
    registered pattern-matcher passes, and random-op replacement, each
    gated by its config flag; recompiles only if something matched.
    """
    lazy_init()
    num_rewrites = 0

    if config.joint_graph_constant_folding:
        constant_fold_uniform_value(graph)

    if config.pattern_matcher:
        num_rewrites += patterns.apply(graph.graph)  # type: ignore[arg-type]

    if not config.fallback_random:
        num_rewrites += replace_random_passes(graph)

    # Only pay for re-sorting/linting/recompiling when a pass changed the graph.
    if num_rewrites:
        stable_topological_sort(graph.graph)
        graph.graph.lint()
        graph.recompile()
    return graph
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        CallFunction(
            torch.ops.prims.convert_element_type.default,
            KeywordArg("arg"),
            KeywordArg("dtype1"),
        ),
        KeywordArg("dtype2"),
    ),
    pass_dict=patterns,
)
def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
    """Remove chain of dtype conversions often created by AMP"""
    graph = match.graph
    node = match.output_node()
    # Collapsing float->float->float is value-preserving only when every
    # dtype involved is a standard floating-point type.
    float_dtypes = {torch.float16, torch.bfloat16, torch.float32, torch.float64}
    if dtype1 in float_dtypes and dtype2 in float_dtypes:
        # Replace the two-hop conversion with a single arg -> dtype2 convert,
        # keeping the original node's metadata.
        replacement = graph.call_function(
            torch.ops.prims.convert_element_type.default, (arg, dtype2)
        )
        replacement.meta.update(node.meta)
        node.replace_all_uses_with(replacement)
    match.erase_nodes(graph)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Remove no-op view"""
    graph = match.graph
    node = match.output_node()
    # A view whose requested size equals the input's recorded shape is a no-op.
    input_shape = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
    if size != input_shape:
        return
    node.replace_all_uses_with(node.args[0])
    match.erase_nodes(graph)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
|
| 8 |
+
from torch._utils_internal import upload_graph
|
| 9 |
+
from torch.fx.experimental.optimization import (
|
| 10 |
+
matches_module_pattern,
|
| 11 |
+
replace_node_module,
|
| 12 |
+
)
|
| 13 |
+
from torch.fx.passes.shape_prop import ShapeProp
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
|
| 16 |
+
|
| 17 |
+
from .. import config
|
| 18 |
+
|
| 19 |
+
from ..fx_utils import matches_module_function_pattern
|
| 20 |
+
from ..pattern_matcher import (
|
| 21 |
+
init_once_fakemode,
|
| 22 |
+
PatternMatcherPass,
|
| 23 |
+
stable_topological_sort,
|
| 24 |
+
)
|
| 25 |
+
from ..utils import is_cpu_device, pass_execution_and_save
|
| 26 |
+
from .group_batch_fusion import group_batch_fusion_passes
|
| 27 |
+
from .misc_patterns import numpy_compat_normalization
|
| 28 |
+
|
| 29 |
+
log = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
# Pre-grad pattern-matcher passes, applied on Torch IR.  Each pass carries a
# pass_name so that hits can be attributed in logging/counters.
normalization_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="normalization_pass"
)
merge_splits_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="merge_splits_pass"
)
split_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="split_cat_pass"
)
unbind_stack_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="unbind_stack_pass"
)
efficient_conv_bn_eval_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
)
merge_getitem_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    prevent_match_across_mutations=True,
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
split_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
unbind_stack_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
merge_getitem_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def fuse_parallel_linear_pass(graph):
    """No-op placeholder pass; always returns ``None`` and leaves ``graph`` untouched."""
    return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def remove_split_ops(graph, shape_prop):
    """No-op placeholder pass; always returns ``None`` and leaves ``graph`` untouched."""
    return None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Ordered list of Torch-IR passes run by pre_grad_passes (non-predispatch path).
pattern_matcher_passes: List[PatternMatcherPass] = [
    normalization_pass,
    merge_getitem_cat_pass,
    merge_splits_pass,
    split_cat_pass,
    unbind_stack_pass,
    efficient_conv_bn_eval_pass,
]
# Ordered list of predispatch-aten-IR passes run when config.is_predispatch is set.
pattern_matcher_passes_aten: List[PatternMatcherPass] = [
    merge_getitem_cat_pass_aten,
    merge_splits_pass_aten,
    split_cat_pass_aten,
    unbind_stack_pass_aten,
]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@init_once_fakemode
def lazy_init():
    """Import the modules that register pre-grad patterns (side-effectful imports).

    Wrapped in @init_once_fakemode so registration happens exactly once,
    under a fake-tensor mode.
    """
    # Imported for their registration side effects only.
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        from . import fb  # type: ignore[attr-defined]  # noqa: F401
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is not functional or normalized, so it is harder
    to write passes on this IR.  Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py which
    are after functionalization and normalization.
    """
    if config.pattern_matcher:
        lazy_init()
        # Snapshot the module before any pass runs so the optional numeric
        # check at the end can compare before/after behavior.
        if hasattr(
            config, "fx_passes_numeric_check"
        ) and config.fx_passes_numeric_check.get("pre_grad", False):
            gm_before_fx_passes = gm.__copy__()
        # explicitly run with predispatch atenIR based passes
        if config.is_predispatch:

            def shape_prop(mod) -> None:
                # Re-propagate shapes through `mod` using the fake mode
                # detected from the example inputs.
                ShapeProp(
                    gm=mod,
                    fake_mode=detect_fake_mode(example_inputs),
                ).propagate(*example_inputs)

            # normalization pass
            pass_execution_and_save(
                normalization_pass_aten.apply,
                gm,
                "[Pre grad(predispatch IR)]Apply normalization pass",
            )
            pass_execution_and_save(
                group_batch_fusion_passes,
                gm,
                "[Pre grad(predispatch IR)] Apply group_batch_fusion",
            )
            pass_execution_and_save(
                fuse_chunk_squeeze_cat_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
            )
            pass_execution_and_save(
                fuse_split_linear_add_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
            )

            log.debug(
                "[Pre grad(predispatch IR)]Before split cat in pre grad pass. graph: %s",
                gm.graph,
            )
            # Run each split/cat aten pass in list order.
            for ind, pattern_matcher_pass_aten in enumerate(
                pattern_matcher_passes_aten
            ):
                pass_execution_and_save(
                    pattern_matcher_pass_aten.apply,
                    gm,
                    f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
                )
            pass_execution_and_save(
                remove_reshape_pass.apply,
                gm,
                "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
            )
            pass_execution_and_save(
                fuse_parallel_linear_pass,
                gm,
                "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
            )
            pass_execution_and_save(
                lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                gm,
                "[Pre grad(predispatch IR)] Apply remove_split_ops",
            )
            # Final shape propagation so downstream consumers see up-to-date meta.
            shape_prop(gm)

        else:
            # We only log the graph with changes to avoid the excessive compilation time
            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
            if example_inputs is not None:
                gm = fuse_fx(gm, example_inputs)
            numpy_compat_normalization(gm.graph)
            # Counters are compared before/after each pass; a change means the
            # pass matched, and only then is the graph uploaded for logging.
            inductor_before_change = copy.deepcopy(counters["inductor"])
            group_batch_fusion_passes(gm.graph, pre_grad=True)
            if counters["inductor"] != inductor_before_change:
                optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph(
                    gm.graph
                )
            for pattern_matcher_pass in pattern_matcher_passes:
                inductor_before_change = copy.deepcopy(counters["inductor"])
                pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
                if counters["inductor"] != inductor_before_change:
                    optimus_scuba_log[
                        f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad"
                    ] = upload_graph(gm.graph)

    # User-supplied custom pass runs regardless of config.pattern_matcher.
    if config.pre_grad_custom_pass is not None:
        config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)
    gm.graph.lint()
    gm.recompile()

    # Optional numerical check between the pre-pass snapshot and the final module.
    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """Run graph-level fusions on a Torch-IR module.

    Applies cat-sinking always, permute/linear/matmul fusions on non-CPU
    devices (when enabled), and conv+BN fusion only for CPU inference
    under freezing.
    """
    is_cpu = is_cpu_device(example_inputs)

    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        # (ShapeProp must run first: the fusions below read node.meta["tensor_meta"]).
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        gm = linear_permute_fusion(gm)
        gm = permute_linear_fusion(gm)
        gm = permute_matmul_fusion(gm)

    # make sure the autograd is disabled.
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        gm = remove_identity(gm)
        gm = fuse_conv_bn(gm)
    return gm
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def fetch_attr(target: str, mod):
    """Resolve a dotted attribute path (e.g. ``"sub.weight"``) against ``mod``.

    Args:
        target: dot-separated attribute path, as stored in an FX node's target.
        mod: object (typically a GraphModule) to resolve the path on.

    Returns:
        The attribute found at the end of the path.

    Raises:
        RuntimeError: if any component along the path does not exist.
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Fixed typo: "nonexistant" -> "nonexistent".
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Return a transformed module with every ``nn.Identity`` call elided.

    Each ``call_module`` node whose submodule is an ``nn.Identity`` is
    replaced by its single input; all other nodes are copied unchanged.
    """

    class _IdentityElider(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            submodule = self.submodules[target]
            if not isinstance(submodule, nn.Identity):
                return super().call_module(target, args, kwargs)
            # Identity forwards its single positional input unchanged.
            assert len(args) == 1
            return args[0]

    return _IdentityElider(gm).transform()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
    """
    Fuses Convolution/BN layers for inference purposes.

    Two shapes are handled: (conv module -> BatchNorm module) pairs, and
    (conv module -> F.batch_norm functional call) pairs whose BN arguments
    are all constants (get_attr nodes).  In both cases the conv module is
    replaced with a fused conv and the BN node is erased.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())
    for pattern in modules_patterns:
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                # Fusion is only valid in eval mode with tracked running stats.
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                # The BN node now forwards the (fused) conv output directly.
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
    for pattern in module_function_patterns:
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue
                # BN stats/affine params must be single-use constants so they
                # can be folded into the conv weights safely.
                bn_args_is_constant = all(
                    n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue
                # Deep-copy so the original conv module is left untouched.
                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
    gm.recompile()

    return gm
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class NormalizedLinearNode:
    """Accessor wrapper over an ``F.linear`` call_function FX node.

    Normalizes positional/keyword argument access: each getter falls back
    to the corresponding kwarg when the positional slot is absent.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        args = self.node.args
        return args[0] if len(args) > 0 else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        args = self.node.args
        if len(args) > 2:
            return args[2]  # type: ignore[return-value]
        # bias is optional: None when neither positional nor keyword is given.
        return self.node.kwargs.get("bias", None)  # type: ignore[return-value]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class NormalizedMatmulNode:
    """Accessor wrapper over a ``torch.bmm``/``torch.matmul`` call_function node.

    Normalizes positional/keyword argument access for the two operands.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        args = self.node.args
        return args[0] if len(args) > 0 else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["other"]  # type: ignore[return-value]
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def check_permute(node: torch.fx.Node) -> bool:
    """Return True iff ``node`` permutes exactly the last two dimensions.

    The permutation is read either from positional args (when more than 3
    args are present) or from a ``permutation`` kwarg of length > 2; axes
    are normalized mod the tensor rank recorded in ``node.meta``.
    """
    rank = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        perm = [node.args[axis] % rank for axis in range(1, rank + 1)]  # type: ignore[operator]
    elif (
        "permutation" in node.kwargs
        and node.kwargs["permutation"] is not None
        and len(node.kwargs["permutation"]) > 2  # type: ignore[arg-type]
    ):
        perm = [axis % rank for axis in node.kwargs["permutation"]]  # type: ignore[union-attr]
    else:
        return False
    # Identity on leading axes, with only the final two swapped.
    swap_last_two = list(range(rank))
    swap_last_two[-1] = rank - 2
    swap_last_two[-2] = rank - 1
    return perm == swap_last_two
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite ``pointwise(cat(xs).view...)`` into ``cat([pointwise(x) for x in xs])``.

    Sinks a cat below a single-user chain of views followed by a unary
    pointwise op (relu/tanh), applying the pointwise op per input instead.
    """

    def one_user(node):
        # The sole user of `node`, or None if it has zero or multiple users.
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        # Both the function and method spellings of relu/tanh are matched.
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        # Walk down through a single-user chain of views after the cat.
        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    # Normalize cat's positional/keyword calling conventions.
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                # Apply the pointwise op to each cat input individually.
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=user.kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                # Rewire: the pointwise op's users read the (last) view/cat,
                # and the original cat's users read the new cat; then erase
                # the now-unused pointwise and cat nodes (order matters).
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse ``F.linear(...).permute(..., -1, -2)`` into a single linear_transpose call.

    Requires ShapeProp metadata on the permute node (read by check_permute).
    """
    for node in module.graph.nodes:
        if (
            node.op == "call_method"
            and node.target == "permute"
            and check_permute(node)
        ):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # Only drop the original linear if nothing else reads it.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# Y1 = X * W^T + bias
|
| 497 |
+
# Y2 = Y1.permute(0, 2, 1)
|
| 498 |
+
# ---->
|
| 499 |
+
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
|
| 500 |
+
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Equivalent of ``F.linear(input, weight, bias)`` with the last two dims swapped.

    Computes ``W @ X^T`` directly (plus a broadcast bias column when given),
    i.e. the transpose of ``X @ W^T + bias``.
    """
    product = torch.matmul(weight, input.transpose(-1, -2))
    if bias is None:
        return product
    return product + bias.unsqueeze(-1)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse ``F.linear(x.permute(..., -1, -2), ...)`` into a single transpose_linear call.

    Requires ShapeProp metadata on the permute node (read by check_permute).
    """
    for node in module.graph.nodes:
        if node.op == "call_function" and node.target == torch.nn.functional.linear:
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_method"
                and input_node.target == "permute"
                and check_permute(input_node)
            ):
                normalized = NormalizedLinearNode(node)
                # The fused op consumes the permute's input directly.
                if len(input_node.args) > 0:
                    input = input_node.args[0]
                else:
                    input = input_node.kwargs["input"]
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_linear, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # Only drop the permute if nothing else reads it.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse permutes of either bmm/matmul operand into a single transpose_matmul call.

    Requires ShapeProp metadata on the permute nodes (read by check_permute).
    """
    for node in module.graph.nodes:
        if node.op == "call_function" and (
            node.target == torch.bmm or node.target == torch.matmul
        ):
            normalized = NormalizedMatmulNode(node)
            input_A_node = normalized.get_input()
            input_B_node = normalized.get_other()
            input_A = input_A_node
            input_B = input_B_node
            # Track which operands come from a last-two-dims permute.
            Atrans = Btrans = False
            if (
                input_A_node.op == "call_method"
                and input_A_node.target == "permute"
                and check_permute(input_A_node)
            ):
                Atrans = True
                if len(input_A_node.args) > 0:
                    input_A = input_A_node.args[0]  # type: ignore[assignment]
                else:
                    input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

            if (
                input_B_node.op == "call_method"
                and input_B_node.target == "permute"
                and check_permute(input_B_node)
            ):
                Btrans = True
                if len(input_B_node.args) > 0:
                    input_B = input_B_node.args[0]  # type: ignore[assignment]
                else:
                    input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

            if Atrans or Btrans:
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_matmul,
                        args=(input_A, input_B, Atrans, Btrans),
                    )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                # Only drop a permute node once it has no remaining users.
                if Atrans and len(input_A_node.users) == 0:
                    module.graph.erase_node(input_A_node)
                if Btrans and len(input_B_node.users) == 0:
                    module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# X1 = X.permute(0, 2, 1)
|
| 593 |
+
# Y1 = X1 * W1^T + bias1
|
| 594 |
+
# ---->
|
| 595 |
+
# Y2 = X1.transpose(-1, -2) * W1^T + bias1
|
| 596 |
+
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Equivalent of ``F.linear(input.transpose(-1, -2), weight, bias)``.

    Used to fuse a preceding last-two-dims permute into the linear itself.
    """
    result = torch.matmul(input.transpose(-1, -2), weight.t())
    if bias is not None:
        result = result + bias
    return result
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    """``torch.matmul`` with each operand optionally transposed on its last two dims.

    Used to fuse permute(-1, -2) nodes feeding a bmm/matmul.
    """
    lhs = A.transpose(-1, -2) if Atrans else A
    rhs = B.transpose(-1, -2) if Btrans else B
    return torch.matmul(lhs, rhs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc
ADDED
|
Binary file (7.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc
ADDED
|
Binary file (54.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc
ADDED
|
Binary file (430 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .core import unify, reify # type: ignore[attr-defined]
|
| 2 |
+
from .variable import isvar
|
| 3 |
+
from .utils import _toposort, freeze
|
| 4 |
+
from .unification_tools import groupby, first # type: ignore[import]
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Dispatcher:
|
| 8 |
+
def __init__(self, name):
|
| 9 |
+
self.name = name
|
| 10 |
+
self.funcs = {}
|
| 11 |
+
self.ordering = []
|
| 12 |
+
|
| 13 |
+
def add(self, signature, func):
|
| 14 |
+
self.funcs[freeze(signature)] = func
|
| 15 |
+
self.ordering = ordering(self.funcs)
|
| 16 |
+
|
| 17 |
+
def __call__(self, *args, **kwargs):
|
| 18 |
+
func, s = self.resolve(args)
|
| 19 |
+
return func(*args, **kwargs)
|
| 20 |
+
|
| 21 |
+
def resolve(self, args):
|
| 22 |
+
n = len(args)
|
| 23 |
+
for signature in self.ordering:
|
| 24 |
+
if len(signature) != n:
|
| 25 |
+
continue
|
| 26 |
+
s = unify(freeze(args), signature)
|
| 27 |
+
if s is not False:
|
| 28 |
+
result = self.funcs[signature]
|
| 29 |
+
return result, s
|
| 30 |
+
raise NotImplementedError("No match found. \nKnown matches: "
|
| 31 |
+
+ str(self.ordering) + "\nInput: " + str(args))
|
| 32 |
+
|
| 33 |
+
def register(self, *signature):
|
| 34 |
+
def _(func):
|
| 35 |
+
self.add(signature, func)
|
| 36 |
+
return self
|
| 37 |
+
return _
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class VarDispatcher(Dispatcher):
|
| 41 |
+
""" A dispatcher that calls functions with variable names
|
| 42 |
+
>>> # xdoctest: +SKIP
|
| 43 |
+
>>> d = VarDispatcher('d')
|
| 44 |
+
>>> x = var('x')
|
| 45 |
+
>>> @d.register('inc', x)
|
| 46 |
+
... def f(x):
|
| 47 |
+
... return x + 1
|
| 48 |
+
>>> @d.register('double', x)
|
| 49 |
+
... def f(x):
|
| 50 |
+
... return x * 2
|
| 51 |
+
>>> d('inc', 10)
|
| 52 |
+
11
|
| 53 |
+
>>> d('double', 10)
|
| 54 |
+
20
|
| 55 |
+
"""
|
| 56 |
+
def __call__(self, *args, **kwargs):
|
| 57 |
+
func, s = self.resolve(args)
|
| 58 |
+
d = {k.token: v for k, v in s.items()}
|
| 59 |
+
return func(**d)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
global_namespace = {} # type: ignore[var-annotated]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def match(*signature, **kwargs):
|
| 66 |
+
namespace = kwargs.get('namespace', global_namespace)
|
| 67 |
+
dispatcher = kwargs.get('Dispatcher', Dispatcher)
|
| 68 |
+
|
| 69 |
+
def _(func):
|
| 70 |
+
name = func.__name__
|
| 71 |
+
|
| 72 |
+
if name not in namespace:
|
| 73 |
+
namespace[name] = dispatcher(name)
|
| 74 |
+
d = namespace[name]
|
| 75 |
+
|
| 76 |
+
d.add(signature, func)
|
| 77 |
+
|
| 78 |
+
return d
|
| 79 |
+
return _
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def supercedes(a, b):
|
| 83 |
+
""" ``a`` is a more specific match than ``b`` """
|
| 84 |
+
if isvar(b) and not isvar(a):
|
| 85 |
+
return True
|
| 86 |
+
s = unify(a, b)
|
| 87 |
+
if s is False:
|
| 88 |
+
return False
|
| 89 |
+
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
|
| 90 |
+
if reify(a, s) == a:
|
| 91 |
+
return True
|
| 92 |
+
if reify(b, s) == b:
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Taken from multipledispatch
|
| 97 |
+
def edge(a, b, tie_breaker=hash):
|
| 98 |
+
""" A should be checked before B
|
| 99 |
+
Tie broken by tie_breaker, defaults to ``hash``
|
| 100 |
+
"""
|
| 101 |
+
if supercedes(a, b):
|
| 102 |
+
if supercedes(b, a):
|
| 103 |
+
return tie_breaker(a) > tie_breaker(b)
|
| 104 |
+
else:
|
| 105 |
+
return True
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Taken from multipledispatch
|
| 110 |
+
def ordering(signatures):
|
| 111 |
+
""" A sane ordering of signatures to check, first to last
|
| 112 |
+
Topological sort of edges as given by ``edge`` and ``supercedes``
|
| 113 |
+
"""
|
| 114 |
+
signatures = list(map(tuple, signatures))
|
| 115 |
+
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
|
| 116 |
+
edges = groupby(first, edges)
|
| 117 |
+
for s in signatures:
|
| 118 |
+
if s not in edges:
|
| 119 |
+
edges[s] = []
|
| 120 |
+
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
|
| 121 |
+
return _toposort(edges)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc
ADDED
|
Binary file (4.77 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import typename
|
| 2 |
+
|
| 3 |
+
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
|
| 4 |
+
|
| 5 |
+
class VariadicSignatureType(type):
|
| 6 |
+
# checking if subclass is a subclass of self
|
| 7 |
+
def __subclasscheck__(cls, subclass):
|
| 8 |
+
other_type = (subclass.variadic_type if isvariadic(subclass)
|
| 9 |
+
else (subclass,))
|
| 10 |
+
return subclass is cls or all(
|
| 11 |
+
issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
def __eq__(cls, other):
|
| 15 |
+
"""
|
| 16 |
+
Return True if other has the same variadic type
|
| 17 |
+
Parameters
|
| 18 |
+
----------
|
| 19 |
+
other : object (type)
|
| 20 |
+
The object (type) to check
|
| 21 |
+
Returns
|
| 22 |
+
-------
|
| 23 |
+
bool
|
| 24 |
+
Whether or not `other` is equal to `self`
|
| 25 |
+
"""
|
| 26 |
+
return (isvariadic(other) and
|
| 27 |
+
set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
|
| 28 |
+
|
| 29 |
+
def __hash__(cls):
|
| 30 |
+
return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def isvariadic(obj):
|
| 34 |
+
"""Check whether the type `obj` is variadic.
|
| 35 |
+
Parameters
|
| 36 |
+
----------
|
| 37 |
+
obj : type
|
| 38 |
+
The type to check
|
| 39 |
+
Returns
|
| 40 |
+
-------
|
| 41 |
+
bool
|
| 42 |
+
Whether or not `obj` is variadic
|
| 43 |
+
Examples
|
| 44 |
+
--------
|
| 45 |
+
>>> # xdoctest: +SKIP
|
| 46 |
+
>>> isvariadic(int)
|
| 47 |
+
False
|
| 48 |
+
>>> isvariadic(Variadic[int])
|
| 49 |
+
True
|
| 50 |
+
"""
|
| 51 |
+
return isinstance(obj, VariadicSignatureType)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class VariadicSignatureMeta(type):
|
| 55 |
+
"""A metaclass that overrides ``__getitem__`` on the class. This is used to
|
| 56 |
+
generate a new type for Variadic signatures. See the Variadic class for
|
| 57 |
+
examples of how this behaves.
|
| 58 |
+
"""
|
| 59 |
+
def __getitem__(cls, variadic_type):
|
| 60 |
+
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
|
| 61 |
+
raise ValueError("Variadic types must be type or tuple of types"
|
| 62 |
+
" (Variadic[int] or Variadic[(int, float)]")
|
| 63 |
+
|
| 64 |
+
if not isinstance(variadic_type, tuple):
|
| 65 |
+
variadic_type = variadic_type,
|
| 66 |
+
return VariadicSignatureType(
|
| 67 |
+
f'Variadic[{typename(variadic_type)}]',
|
| 68 |
+
(),
|
| 69 |
+
dict(variadic_type=variadic_type, __slots__=())
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Variadic(metaclass=VariadicSignatureMeta):
|
| 74 |
+
"""A class whose getitem method can be used to generate a new type
|
| 75 |
+
representing a specific variadic signature.
|
| 76 |
+
Examples
|
| 77 |
+
--------
|
| 78 |
+
>>> # xdoctest: +SKIP
|
| 79 |
+
>>> Variadic[int] # any number of int arguments
|
| 80 |
+
<class 'multipledispatch.variadic.Variadic[int]'>
|
| 81 |
+
>>> Variadic[(int, str)] # any number of one of int or str arguments
|
| 82 |
+
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
|
| 83 |
+
>>> issubclass(int, Variadic[int])
|
| 84 |
+
True
|
| 85 |
+
>>> issubclass(int, Variadic[(int, str)])
|
| 86 |
+
True
|
| 87 |
+
>>> issubclass(str, Variadic[(int, str)])
|
| 88 |
+
True
|
| 89 |
+
>>> issubclass(float, Variadic[(int, str)])
|
| 90 |
+
False
|
| 91 |
+
"""
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import operator
|
| 3 |
+
from functools import reduce
|
| 4 |
+
from collections.abc import Mapping
|
| 5 |
+
|
| 6 |
+
__all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
|
| 7 |
+
'valfilter', 'keyfilter', 'itemfilter',
|
| 8 |
+
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _get_factory(f, kwargs):
|
| 12 |
+
factory = kwargs.pop('factory', dict)
|
| 13 |
+
if kwargs:
|
| 14 |
+
raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
|
| 15 |
+
return factory
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def merge(*dicts, **kwargs):
|
| 19 |
+
""" Merge a collection of dictionaries
|
| 20 |
+
|
| 21 |
+
>>> merge({1: 'one'}, {2: 'two'})
|
| 22 |
+
{1: 'one', 2: 'two'}
|
| 23 |
+
|
| 24 |
+
Later dictionaries have precedence
|
| 25 |
+
|
| 26 |
+
>>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
|
| 27 |
+
{1: 2, 3: 3, 4: 4}
|
| 28 |
+
|
| 29 |
+
See Also:
|
| 30 |
+
merge_with
|
| 31 |
+
"""
|
| 32 |
+
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
|
| 33 |
+
dicts = dicts[0]
|
| 34 |
+
factory = _get_factory(merge, kwargs)
|
| 35 |
+
|
| 36 |
+
rv = factory()
|
| 37 |
+
for d in dicts:
|
| 38 |
+
rv.update(d)
|
| 39 |
+
return rv
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def merge_with(func, *dicts, **kwargs):
|
| 43 |
+
""" Merge dictionaries and apply function to combined values
|
| 44 |
+
|
| 45 |
+
A key may occur in more than one dict, and all values mapped from the key
|
| 46 |
+
will be passed to the function as a list, such as func([val1, val2, ...]).
|
| 47 |
+
|
| 48 |
+
>>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
|
| 49 |
+
{1: 11, 2: 22}
|
| 50 |
+
|
| 51 |
+
>>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
|
| 52 |
+
{1: 1, 2: 2, 3: 30}
|
| 53 |
+
|
| 54 |
+
See Also:
|
| 55 |
+
merge
|
| 56 |
+
"""
|
| 57 |
+
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
|
| 58 |
+
dicts = dicts[0]
|
| 59 |
+
factory = _get_factory(merge_with, kwargs)
|
| 60 |
+
|
| 61 |
+
result = factory()
|
| 62 |
+
for d in dicts:
|
| 63 |
+
for k, v in d.items():
|
| 64 |
+
if k not in result:
|
| 65 |
+
result[k] = [v]
|
| 66 |
+
else:
|
| 67 |
+
result[k].append(v)
|
| 68 |
+
return valmap(func, result, factory)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def valmap(func, d, factory=dict):
|
| 72 |
+
""" Apply function to values of dictionary
|
| 73 |
+
|
| 74 |
+
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
|
| 75 |
+
>>> valmap(sum, bills) # doctest: +SKIP
|
| 76 |
+
{'Alice': 65, 'Bob': 45}
|
| 77 |
+
|
| 78 |
+
See Also:
|
| 79 |
+
keymap
|
| 80 |
+
itemmap
|
| 81 |
+
"""
|
| 82 |
+
rv = factory()
|
| 83 |
+
rv.update(zip(d.keys(), map(func, d.values())))
|
| 84 |
+
return rv
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def keymap(func, d, factory=dict):
|
| 88 |
+
""" Apply function to keys of dictionary
|
| 89 |
+
|
| 90 |
+
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
|
| 91 |
+
>>> keymap(str.lower, bills) # doctest: +SKIP
|
| 92 |
+
{'alice': [20, 15, 30], 'bob': [10, 35]}
|
| 93 |
+
|
| 94 |
+
See Also:
|
| 95 |
+
valmap
|
| 96 |
+
itemmap
|
| 97 |
+
"""
|
| 98 |
+
rv = factory()
|
| 99 |
+
rv.update(zip(map(func, d.keys()), d.values()))
|
| 100 |
+
return rv
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def itemmap(func, d, factory=dict):
|
| 104 |
+
""" Apply function to items of dictionary
|
| 105 |
+
|
| 106 |
+
>>> accountids = {"Alice": 10, "Bob": 20}
|
| 107 |
+
>>> itemmap(reversed, accountids) # doctest: +SKIP
|
| 108 |
+
{10: "Alice", 20: "Bob"}
|
| 109 |
+
|
| 110 |
+
See Also:
|
| 111 |
+
keymap
|
| 112 |
+
valmap
|
| 113 |
+
"""
|
| 114 |
+
rv = factory()
|
| 115 |
+
rv.update(map(func, d.items()))
|
| 116 |
+
return rv
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def valfilter(predicate, d, factory=dict):
|
| 120 |
+
""" Filter items in dictionary by value
|
| 121 |
+
|
| 122 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 123 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 124 |
+
>>> valfilter(iseven, d)
|
| 125 |
+
{1: 2, 3: 4}
|
| 126 |
+
|
| 127 |
+
See Also:
|
| 128 |
+
keyfilter
|
| 129 |
+
itemfilter
|
| 130 |
+
valmap
|
| 131 |
+
"""
|
| 132 |
+
rv = factory()
|
| 133 |
+
for k, v in d.items():
|
| 134 |
+
if predicate(v):
|
| 135 |
+
rv[k] = v
|
| 136 |
+
return rv
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def keyfilter(predicate, d, factory=dict):
|
| 140 |
+
""" Filter items in dictionary by key
|
| 141 |
+
|
| 142 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 143 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 144 |
+
>>> keyfilter(iseven, d)
|
| 145 |
+
{2: 3, 4: 5}
|
| 146 |
+
|
| 147 |
+
See Also:
|
| 148 |
+
valfilter
|
| 149 |
+
itemfilter
|
| 150 |
+
keymap
|
| 151 |
+
"""
|
| 152 |
+
rv = factory()
|
| 153 |
+
for k, v in d.items():
|
| 154 |
+
if predicate(k):
|
| 155 |
+
rv[k] = v
|
| 156 |
+
return rv
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def itemfilter(predicate, d, factory=dict):
|
| 160 |
+
""" Filter items in dictionary by item
|
| 161 |
+
|
| 162 |
+
>>> def isvalid(item):
|
| 163 |
+
... k, v = item
|
| 164 |
+
... return k % 2 == 0 and v < 4
|
| 165 |
+
|
| 166 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 167 |
+
>>> itemfilter(isvalid, d)
|
| 168 |
+
{2: 3}
|
| 169 |
+
|
| 170 |
+
See Also:
|
| 171 |
+
keyfilter
|
| 172 |
+
valfilter
|
| 173 |
+
itemmap
|
| 174 |
+
"""
|
| 175 |
+
rv = factory()
|
| 176 |
+
for item in d.items():
|
| 177 |
+
if predicate(item):
|
| 178 |
+
k, v = item
|
| 179 |
+
rv[k] = v
|
| 180 |
+
return rv
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def assoc(d, key, value, factory=dict):
|
| 184 |
+
""" Return a new dict with new key value pair
|
| 185 |
+
|
| 186 |
+
New dict has d[key] set to value. Does not modify the initial dictionary.
|
| 187 |
+
|
| 188 |
+
>>> assoc({'x': 1}, 'x', 2)
|
| 189 |
+
{'x': 2}
|
| 190 |
+
>>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
|
| 191 |
+
{'x': 1, 'y': 3}
|
| 192 |
+
"""
|
| 193 |
+
d2 = factory()
|
| 194 |
+
d2.update(d)
|
| 195 |
+
d2[key] = value
|
| 196 |
+
return d2
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def dissoc(d, *keys, **kwargs):
|
| 200 |
+
""" Return a new dict with the given key(s) removed.
|
| 201 |
+
|
| 202 |
+
New dict has d[key] deleted for each supplied key.
|
| 203 |
+
Does not modify the initial dictionary.
|
| 204 |
+
|
| 205 |
+
>>> dissoc({'x': 1, 'y': 2}, 'y')
|
| 206 |
+
{'x': 1}
|
| 207 |
+
>>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
|
| 208 |
+
{}
|
| 209 |
+
>>> dissoc({'x': 1}, 'y') # Ignores missing keys
|
| 210 |
+
{'x': 1}
|
| 211 |
+
"""
|
| 212 |
+
factory = _get_factory(dissoc, kwargs)
|
| 213 |
+
d2 = factory()
|
| 214 |
+
|
| 215 |
+
if len(keys) < len(d) * .6:
|
| 216 |
+
d2.update(d)
|
| 217 |
+
for key in keys:
|
| 218 |
+
if key in d2:
|
| 219 |
+
del d2[key]
|
| 220 |
+
else:
|
| 221 |
+
remaining = set(d)
|
| 222 |
+
remaining.difference_update(keys)
|
| 223 |
+
for k in remaining:
|
| 224 |
+
d2[k] = d[k]
|
| 225 |
+
return d2
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def assoc_in(d, keys, value, factory=dict):
|
| 229 |
+
""" Return a new dict with new, potentially nested, key value pair
|
| 230 |
+
|
| 231 |
+
>>> purchase = {'name': 'Alice',
|
| 232 |
+
... 'order': {'items': ['Apple', 'Orange'],
|
| 233 |
+
... 'costs': [0.50, 1.25]},
|
| 234 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 235 |
+
>>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
|
| 236 |
+
{'credit card': '5555-1234-1234-1234',
|
| 237 |
+
'name': 'Alice',
|
| 238 |
+
'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
|
| 239 |
+
"""
|
| 240 |
+
return update_in(d, keys, lambda x: value, value, factory)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def update_in(d, keys, func, default=None, factory=dict):
|
| 244 |
+
""" Update value in a (potentially) nested dictionary
|
| 245 |
+
|
| 246 |
+
inputs:
|
| 247 |
+
d - dictionary on which to operate
|
| 248 |
+
keys - list or tuple giving the location of the value to be changed in d
|
| 249 |
+
func - function to operate on that value
|
| 250 |
+
|
| 251 |
+
If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
|
| 252 |
+
original dictionary with v replaced by func(v), but does not mutate the
|
| 253 |
+
original dictionary.
|
| 254 |
+
|
| 255 |
+
If k0 is not a key in d, update_in creates nested dictionaries to the depth
|
| 256 |
+
specified by the keys, with the innermost value set to func(default).
|
| 257 |
+
|
| 258 |
+
>>> inc = lambda x: x + 1
|
| 259 |
+
>>> update_in({'a': 0}, ['a'], inc)
|
| 260 |
+
{'a': 1}
|
| 261 |
+
|
| 262 |
+
>>> transaction = {'name': 'Alice',
|
| 263 |
+
... 'purchase': {'items': ['Apple', 'Orange'],
|
| 264 |
+
... 'costs': [0.50, 1.25]},
|
| 265 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 266 |
+
>>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
|
| 267 |
+
{'credit card': '5555-1234-1234-1234',
|
| 268 |
+
'name': 'Alice',
|
| 269 |
+
'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
|
| 270 |
+
|
| 271 |
+
>>> # updating a value when k0 is not in d
|
| 272 |
+
>>> update_in({}, [1, 2, 3], str, default="bar")
|
| 273 |
+
{1: {2: {3: 'bar'}}}
|
| 274 |
+
>>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
|
| 275 |
+
{1: 'foo', 2: {3: {4: 1}}}
|
| 276 |
+
"""
|
| 277 |
+
ks = iter(keys)
|
| 278 |
+
k = next(ks)
|
| 279 |
+
|
| 280 |
+
rv = inner = factory()
|
| 281 |
+
rv.update(d)
|
| 282 |
+
|
| 283 |
+
for key in ks:
|
| 284 |
+
if k in d:
|
| 285 |
+
d = d[k]
|
| 286 |
+
dtemp = factory()
|
| 287 |
+
dtemp.update(d)
|
| 288 |
+
else:
|
| 289 |
+
d = dtemp = factory()
|
| 290 |
+
|
| 291 |
+
inner[k] = inner = dtemp
|
| 292 |
+
k = key
|
| 293 |
+
|
| 294 |
+
if k in d:
|
| 295 |
+
inner[k] = func(d[k])
|
| 296 |
+
else:
|
| 297 |
+
inner[k] = func(default)
|
| 298 |
+
return rv
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def get_in(keys, coll, default=None, no_default=False):
|
| 302 |
+
""" Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
|
| 303 |
+
|
| 304 |
+
If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
|
| 305 |
+
``no_default`` is specified, then it raises KeyError or IndexError.
|
| 306 |
+
|
| 307 |
+
``get_in`` is a generalization of ``operator.getitem`` for nested data
|
| 308 |
+
structures such as dictionaries and lists.
|
| 309 |
+
|
| 310 |
+
>>> transaction = {'name': 'Alice',
|
| 311 |
+
... 'purchase': {'items': ['Apple', 'Orange'],
|
| 312 |
+
... 'costs': [0.50, 1.25]},
|
| 313 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 314 |
+
>>> get_in(['purchase', 'items', 0], transaction)
|
| 315 |
+
'Apple'
|
| 316 |
+
>>> get_in(['name'], transaction)
|
| 317 |
+
'Alice'
|
| 318 |
+
>>> get_in(['purchase', 'total'], transaction)
|
| 319 |
+
>>> get_in(['purchase', 'items', 'apple'], transaction)
|
| 320 |
+
>>> get_in(['purchase', 'items', 10], transaction)
|
| 321 |
+
>>> get_in(['purchase', 'total'], transaction, 0)
|
| 322 |
+
0
|
| 323 |
+
>>> get_in(['y'], {}, no_default=True)
|
| 324 |
+
Traceback (most recent call last):
|
| 325 |
+
...
|
| 326 |
+
KeyError: 'y'
|
| 327 |
+
|
| 328 |
+
See Also:
|
| 329 |
+
itertoolz.get
|
| 330 |
+
operator.getitem
|
| 331 |
+
"""
|
| 332 |
+
try:
|
| 333 |
+
return reduce(operator.getitem, keys, coll)
|
| 334 |
+
except (KeyError, IndexError, TypeError):
|
| 335 |
+
if no_default:
|
| 336 |
+
raise
|
| 337 |
+
return default
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def getter(index):
|
| 341 |
+
if isinstance(index, list):
|
| 342 |
+
if len(index) == 1:
|
| 343 |
+
index = index[0]
|
| 344 |
+
return lambda x: (x[index],)
|
| 345 |
+
elif index:
|
| 346 |
+
return operator.itemgetter(*index)
|
| 347 |
+
else:
|
| 348 |
+
return lambda x: ()
|
| 349 |
+
else:
|
| 350 |
+
return operator.itemgetter(index)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def groupby(key, seq):
|
| 354 |
+
""" Group a collection by a key function
|
| 355 |
+
|
| 356 |
+
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
|
| 357 |
+
>>> groupby(len, names) # doctest: +SKIP
|
| 358 |
+
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
|
| 359 |
+
|
| 360 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 361 |
+
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
|
| 362 |
+
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
|
| 363 |
+
|
| 364 |
+
Non-callable keys imply grouping on a member.
|
| 365 |
+
|
| 366 |
+
>>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
|
| 367 |
+
... {'name': 'Bob', 'gender': 'M'},
|
| 368 |
+
... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
|
| 369 |
+
{'F': [{'gender': 'F', 'name': 'Alice'}],
|
| 370 |
+
'M': [{'gender': 'M', 'name': 'Bob'},
|
| 371 |
+
{'gender': 'M', 'name': 'Charlie'}]}
|
| 372 |
+
|
| 373 |
+
Not to be confused with ``itertools.groupby``
|
| 374 |
+
|
| 375 |
+
See Also:
|
| 376 |
+
countby
|
| 377 |
+
"""
|
| 378 |
+
if not callable(key):
|
| 379 |
+
key = getter(key)
|
| 380 |
+
d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
|
| 381 |
+
for item in seq:
|
| 382 |
+
d[key(item)](item)
|
| 383 |
+
rv = {}
|
| 384 |
+
for k, v in d.items():
|
| 385 |
+
rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
|
| 386 |
+
return rv
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def first(seq):
|
| 390 |
+
""" The first element in a sequence
|
| 391 |
+
|
| 392 |
+
>>> first('ABC')
|
| 393 |
+
'A'
|
| 394 |
+
"""
|
| 395 |
+
return next(iter(seq))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
|
| 2 |
+
def hashable(x):
|
| 3 |
+
try:
|
| 4 |
+
hash(x)
|
| 5 |
+
return True
|
| 6 |
+
except TypeError:
|
| 7 |
+
return False
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def transitive_get(key, d):
|
| 11 |
+
""" Transitive dict.get
|
| 12 |
+
>>> d = {1: 2, 2: 3, 3: 4}
|
| 13 |
+
>>> d.get(1)
|
| 14 |
+
2
|
| 15 |
+
>>> transitive_get(1, d)
|
| 16 |
+
4
|
| 17 |
+
"""
|
| 18 |
+
while hashable(key) and key in d:
|
| 19 |
+
key = d[key]
|
| 20 |
+
return key
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def raises(err, lamda):
|
| 24 |
+
try:
|
| 25 |
+
lamda()
|
| 26 |
+
return False
|
| 27 |
+
except err:
|
| 28 |
+
return True
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Taken from theano/theano/gof/sched.py
|
| 32 |
+
# Avoids licensing issues because this was written by Matthew Rocklin
|
| 33 |
+
def _toposort(edges):
|
| 34 |
+
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
|
| 35 |
+
inputs:
|
| 36 |
+
edges - a dict of the form {a: {b, c}} where b and c depend on a
|
| 37 |
+
outputs:
|
| 38 |
+
L - an ordered list of nodes that satisfy the dependencies of edges
|
| 39 |
+
>>> # xdoctest: +SKIP
|
| 40 |
+
>>> _toposort({1: (2, 3), 2: (3, )})
|
| 41 |
+
[1, 2, 3]
|
| 42 |
+
Closely follows the wikipedia page [2]
|
| 43 |
+
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
|
| 44 |
+
Communications of the ACM
|
| 45 |
+
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
|
| 46 |
+
"""
|
| 47 |
+
incoming_edges = reverse_dict(edges)
|
| 48 |
+
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
|
| 49 |
+
S = ({v for v in edges if v not in incoming_edges})
|
| 50 |
+
L = []
|
| 51 |
+
|
| 52 |
+
while S:
|
| 53 |
+
n = S.pop()
|
| 54 |
+
L.append(n)
|
| 55 |
+
for m in edges.get(n, ()):
|
| 56 |
+
assert n in incoming_edges[m]
|
| 57 |
+
incoming_edges[m].remove(n)
|
| 58 |
+
if not incoming_edges[m]:
|
| 59 |
+
S.add(m)
|
| 60 |
+
if any(incoming_edges.get(v, None) for v in edges):
|
| 61 |
+
raise ValueError("Input has cycles")
|
| 62 |
+
return L
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def reverse_dict(d):
|
| 66 |
+
"""Reverses direction of dependence dict
|
| 67 |
+
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
|
| 68 |
+
>>> reverse_dict(d) # doctest: +SKIP
|
| 69 |
+
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
|
| 70 |
+
:note: dict order are not deterministic. As we iterate on the
|
| 71 |
+
input dict, it make the output of this function depend on the
|
| 72 |
+
dict order. So this function output order should be considered
|
| 73 |
+
as undeterministic.
|
| 74 |
+
"""
|
| 75 |
+
result = {} # type: ignore[var-annotated]
|
| 76 |
+
for key in d:
|
| 77 |
+
for val in d[key]:
|
| 78 |
+
result[val] = result.get(val, tuple()) + (key, )
|
| 79 |
+
return result
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def xfail(func):
|
| 83 |
+
try:
|
| 84 |
+
func()
|
| 85 |
+
raise Exception("XFailed test passed") # pragma:nocover
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def freeze(d):
|
| 91 |
+
""" Freeze container to hashable form
|
| 92 |
+
>>> freeze(1)
|
| 93 |
+
1
|
| 94 |
+
>>> freeze([1, 2])
|
| 95 |
+
(1, 2)
|
| 96 |
+
>>> freeze({1: 2}) # doctest: +SKIP
|
| 97 |
+
frozenset([(1, 2)])
|
| 98 |
+
"""
|
| 99 |
+
if isinstance(d, dict):
|
| 100 |
+
return frozenset(map(freeze, d.items()))
|
| 101 |
+
if isinstance(d, set):
|
| 102 |
+
return frozenset(map(freeze, d))
|
| 103 |
+
if isinstance(d, (tuple, list)):
|
| 104 |
+
return tuple(map(freeze, d))
|
| 105 |
+
return d
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
from .utils import hashable
|
| 3 |
+
from .dispatch import dispatch
|
| 4 |
+
|
| 5 |
+
_global_logic_variables = set() # type: ignore[var-annotated]
|
| 6 |
+
_glv = _global_logic_variables
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Var:
|
| 10 |
+
""" Logic Variable """
|
| 11 |
+
|
| 12 |
+
_id = 1
|
| 13 |
+
|
| 14 |
+
def __new__(cls, *token):
|
| 15 |
+
if len(token) == 0:
|
| 16 |
+
token = f"_{Var._id}" # type: ignore[assignment]
|
| 17 |
+
Var._id += 1
|
| 18 |
+
elif len(token) == 1:
|
| 19 |
+
token = token[0]
|
| 20 |
+
|
| 21 |
+
obj = object.__new__(cls)
|
| 22 |
+
obj.token = token # type: ignore[attr-defined]
|
| 23 |
+
return obj
|
| 24 |
+
|
| 25 |
+
def __str__(self):
|
| 26 |
+
return "~" + str(self.token) # type: ignore[attr-defined]
|
| 27 |
+
__repr__ = __str__
|
| 28 |
+
|
| 29 |
+
def __eq__(self, other):
|
| 30 |
+
return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined]
|
| 31 |
+
|
| 32 |
+
def __hash__(self):
|
| 33 |
+
return hash((type(self), self.token)) # type: ignore[attr-defined]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def var():
|
| 37 |
+
return lambda *args: Var(*args)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def vars():
|
| 41 |
+
return lambda n: [var() for i in range(n)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dispatch(Var)
|
| 45 |
+
def isvar(v):
|
| 46 |
+
return True
|
| 47 |
+
|
| 48 |
+
isvar
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dispatch(object) # type: ignore[no-redef]
|
| 52 |
+
def isvar(o):
|
| 53 |
+
return not not _glv and hashable(o) and o in _glv
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@contextmanager
|
| 57 |
+
def variables(*variables):
|
| 58 |
+
"""
|
| 59 |
+
Context manager for logic variables
|
| 60 |
+
|
| 61 |
+
Example:
|
| 62 |
+
>>> # xdoctest: +SKIP("undefined vars")
|
| 63 |
+
>>> from __future__ import with_statement
|
| 64 |
+
>>> with variables(1):
|
| 65 |
+
... print(isvar(1))
|
| 66 |
+
True
|
| 67 |
+
>>> print(isvar(1))
|
| 68 |
+
False
|
| 69 |
+
>>> # Normal approach
|
| 70 |
+
>>> from unification import unify
|
| 71 |
+
>>> x = var('x')
|
| 72 |
+
>>> unify(x, 1)
|
| 73 |
+
{~x: 1}
|
| 74 |
+
>>> # Context Manager approach
|
| 75 |
+
>>> with variables('x'):
|
| 76 |
+
... print(unify('x', 1))
|
| 77 |
+
{'x': 1}
|
| 78 |
+
"""
|
| 79 |
+
old_global_logic_variables = _global_logic_variables.copy()
|
| 80 |
+
_global_logic_variables.update(set(variables))
|
| 81 |
+
try:
|
| 82 |
+
yield
|
| 83 |
+
finally:
|
| 84 |
+
_global_logic_variables.clear()
|
| 85 |
+
_global_logic_variables.update(old_global_logic_variables)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Generator.h>
|
| 6 |
+
#include <ATen/core/PhiloxRNGEngine.h>
|
| 7 |
+
#include <c10/core/GeneratorImpl.h>
|
| 8 |
+
#include <c10/util/Optional.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
namespace mps::detail {
|
| 12 |
+
|
| 13 |
+
// Number of 32-bit words of Philox counter/key state kept per generator.
static const uint32_t PHILOX_STATE_N = 7;

// Plain-old-data bundle holding the generator's RNG state: the Philox
// state words plus the 64-bit seed.
struct rng_data_pod {
  // Brace-init sets state[0] to 1; the remaining words are zero-initialized.
  std::array<uint32_t, PHILOX_STATE_N> state{1};
  uint64_t seed = default_rng_seed_val;
};
|
| 18 |
+
|
| 19 |
+
TORCH_API const Generator& getDefaultMPSGenerator();
|
| 20 |
+
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 21 |
+
|
| 22 |
+
} // namespace mps::detail
|
| 23 |
+
|
| 24 |
+
// RNG generator for the MPS (Metal Performance Shaders) backend, backed by
// a Philox4x32 counter-based engine.  Implements c10::GeneratorImpl so it
// can be used wherever an at::Generator is expected.
struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  ~MPSGeneratorImpl() override = default;

  // MPSGeneratorImpl methods
  std::shared_ptr<MPSGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void update_philox_counters();

  void set_engine(at::Philox4_32 engine) { engine_ = engine; };
  // Returns the engine by value (a copy of the Philox state).
  at::Philox4_32 engine() { return engine_; };
  // Raw, non-owning pointer into the POD state words — presumably shared
  // with the Metal side; verify lifetime expectations at call sites.
  uint32_t* state_data() { return data_.state.data(); }
  static DeviceType device_type() { return DeviceType::MPS; };

private:
  mps::detail::rng_data_pod data_;
  at::Philox4_32 engine_;

  // GeneratorImpl hook that backs the public clone() above.
  MPSGeneratorImpl* clone_impl() const override;
};
|
| 51 |
+
|
| 52 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/detail/MPSHooksInterface.h>
|
| 6 |
+
#include <ATen/Generator.h>
|
| 7 |
+
#include <ATen/mps/MPSEvent.h>
|
| 8 |
+
#include <c10/util/Optional.h>
|
| 9 |
+
|
| 10 |
+
namespace at::mps {
|
| 11 |
+
|
| 12 |
+
// The real implementation of MPSHooksInterface: the concrete hook object
// registered for the MPS backend.  Each method forwards into the
// corresponding MPS subsystem (device, generator, stream, allocator,
// profiler, events); implementations live in the .mm translation unit.
struct MPSHooks : public at::MPSHooksInterface {
  MPSHooks(at::MPSHooksArgs) {}
  void initMPS() const override;

  // MPSDevice interface
  bool hasMPS() const override;
  bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;

  // MPSGeneratorImpl interface
  const Generator& getDefaultMPSGenerator() const override;

  // MPSStream interface
  void deviceSynchronize() const override;
  void commitStream() const override;
  // NOTE(review): returned as opaque void* to keep Objective-C types out
  // of this interface — actual pointee types are not visible here.
  void* getCommandBuffer() const override;
  void* getDispatchQueue() const override;

  // MPSAllocator interface
  Allocator* getMPSDeviceAllocator() const override;
  void emptyCache() const override;
  size_t getCurrentAllocatedMemory() const override;
  size_t getDriverAllocatedMemory() const override;
  void setMemoryFraction(double ratio) const override;

  // MPSProfiler interface
  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
  void profilerStopTrace() const override;

  // MPSEvent interface (events are referred to by integer ids)
  uint32_t acquireEvent(bool enable_timing) const override;
  void releaseEvent(uint32_t event_id) const override;
  void recordEvent(uint32_t event_id) const override;
  void waitForEvent(uint32_t event_id) const override;
  void synchronizeEvent(uint32_t event_id) const override;
  bool queryEvent(uint32_t event_id) const override;
  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;

  // Compatibility with Accelerator API
  bool hasPrimaryContext(DeviceIndex device_index) const override {
    // When MPS is available, it is always in use for the one device.
    return true;
  }
};
|
| 56 |
+
|
| 57 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <utility>
|
| 7 |
+
|
| 8 |
+
#include <c10/core/DeviceGuard.h>
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
#include <c10/core/Stream.h>
|
| 11 |
+
#include <ATen/mps/MPSDevice.h>
|
| 12 |
+
|
| 13 |
+
// Under Objective-C(++) the MTL* aliases are real Metal protocol types;
// in plain C++ translation units they degrade to opaque void* so this
// header can still be parsed.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// NOTE(review): the trailing ';' in this macro expands into every use of
// `nil` (e.g. `= nil;` becomes `= NULL;;`). It happens to compile because
// stray semicolons are tolerated where `nil` is used in this header, but
// the semicolon looks unintentional — confirm before touching.
#define nil NULL;
#endif
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
namespace at::mps {
|
| 37 |
+
|
| 38 |
+
//-----------------------------------------------------------------
|
| 39 |
+
// MPSStream
|
| 40 |
+
//-----------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
// Commit/synchronization policy accepted by MPSStream::synchronize() and
// the fill/copy/executeMPSGraph entry points below.
enum class SyncType {
  NONE, // no commit to command buffer
  COMMIT, // commit and flush the command buffer
  COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
  COMMIT_ADAPTIVE, // commit adaptively based on available memory
};
|
| 49 |
+
|
| 50 |
+
// Wrapper around a c10::Stream for the MPS backend.  Owns the Metal
// command queue / command buffer / compute encoder plumbing and provides
// fill/copy/graph-execution primitives whose commit behavior is selected
// via SyncType.
class TORCH_API MPSStream
{
public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();
  MTLCommandQueue_t commandQueue() const { return _commandQueue; };
  dispatch_queue_t queue() const { return _serialQueue; }

  // Lazily-managed command buffer / encoder accessors (see .mm for setup).
  MPSCommandBuffer* commandBuffer();
  MTLComputeCommandEncoder_t commandEncoder();
  void endKernelCoalescing();
  void synchronize(SyncType syncType);
  // Fill `length` bytes of `buffer` starting at `offset` with `value`.
  void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
            size_t length, size_t srcOffset, size_t dstOffset,
            uint64_t profileId, SyncType syncType = SyncType::NONE);
  // Like copy(), but chooses the synchronization behavior from `non_blocking`.
  void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
                     size_t length, size_t srcOffset, size_t dstOffset,
                     bool non_blocking, uint64_t profileId);
  void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
  void addCompletedHandler(MTLCommandBufferHandler block);

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  // Alias of commandQueue(), kept for stream-like naming parity.
  MTLCommandQueue_t stream() const { return _commandQueue; };

  MTLDevice_t device() const { return [_commandQueue device];}

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

private:
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MPSCommandBuffer* _commandBuffer = nil;
  MPSCommandBuffer* _prevCommandBuffer = nil;
  MTLComputeCommandEncoder_t _commandEncoder = nil;
  MPSGraphExecutionDescriptor *_executionDescriptor = nil;
  MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
  dispatch_queue_t _serialQueue = nullptr;
  // CommitAndContinue is enabled by default
  bool _enableCommitAndContinue = true;

  // use synchronize() to access any of these commit functions outside MPSStream
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};
|
| 105 |
+
|
| 106 |
+
/**
|
| 107 |
+
* Get the current MPS stream
|
| 108 |
+
*/
|
| 109 |
+
TORCH_API MPSStream* getCurrentMPSStream();
|
| 110 |
+
|
| 111 |
+
/**
|
| 112 |
+
* Get the default MPS stream
|
| 113 |
+
*/
|
| 114 |
+
TORCH_API MPSStream* getDefaultMPSStream();
|
| 115 |
+
|
| 116 |
+
//-----------------------------------------------------------------
|
| 117 |
+
// MPSStreamImpl
|
| 118 |
+
//-----------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
// Singleton holder for the process-wide MPSStream.  The private
// constructor prevents direct instantiation; access goes through
// getInstance().
class TORCH_API MPSStreamImpl
{
public:
  /**
   * Gets single instance of the MPSStream.
   */
  static MPSStream* getInstance();

private:
  // The lazily-created shared stream (definition lives in the .mm file).
  static MPSStream* _stream;
  MPSStreamImpl();
};
|
| 132 |
+
|
| 133 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
// Dispatcher stub for the in-place overload.  Per the (a!)/(b!)
// annotations in schema_str below, `self` and `found_inf` are mutated.
struct TORCH_API _amp_foreach_non_finite_check_and_unscale_ {
  using schema = void (at::TensorList, at::Tensor &, const at::Tensor &);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale_")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()")
  static void call(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
  static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
};
|
| 27 |
+
|
| 28 |
+
// Dispatcher stub for the `.out` overload: results are written into the
// caller-provided `out` tensor list (and `found_inf`, per (b!) below).
struct TORCH_API _amp_foreach_non_finite_check_and_unscale_out {
  using schema = void (at::TensorList, at::Tensor &, const at::Tensor &, at::TensorList);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()")
  static void call(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out);
  static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out);
};
|
| 38 |
+
|
| 39 |
+
// Dispatcher stub for the functional overload: nothing is mutated; new
// `self_out` tensors and a `found_inf_out` tensor are returned instead.
struct TORCH_API _amp_foreach_non_finite_check_and_unscale {
  using schema = ::std::tuple<::std::vector<at::Tensor>,at::Tensor> (at::TensorList, const at::Tensor &, const at::Tensor &);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_amp_foreach_non_finite_check_and_unscale")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)")
  static ::std::tuple<::std::vector<at::Tensor>,at::Tensor> call(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale);
  static ::std::tuple<::std::vector<at::Tensor>,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale);
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_copy_from_native.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor & _copy_from_out(const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out);
|
| 20 |
+
} // namespace native
|
| 21 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {

// torchgen-generated boilerplate: every overload below simply forwards to
// the dispatcher entry point
// at::_ops::_empty_per_channel_affine_quantized[_out]::call(...).
// The nested `symint` namespace holds templated variants selected on
// whether the caller traffics in int64_t or c10::SymInt sizes.

// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
  at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
  }
}

// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
    return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
  at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
    return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
  }
}

// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
  at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
  }
}

// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
    return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
  at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
    return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
  }
}

// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
  at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
  }
}

// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
  at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
  }
}

// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
  at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
  }
}

// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
}
namespace symint {
  template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
  at::Tensor & _empty_per_channel_affine_quantized_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
    return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
  }
}

}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2r_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _fft_c2r(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size);
|
| 21 |
+
TORCH_API at::Tensor _fft_c2r_symint(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size);
|
| 22 |
+
TORCH_API at::Tensor & _fft_c2r_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size);
|
| 23 |
+
TORCH_API at::Tensor & _fft_c2r_outf(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size, at::Tensor & out);
|
| 24 |
+
TORCH_API at::Tensor & _fft_c2r_symint_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size);
|
| 25 |
+
TORCH_API at::Tensor & _fft_c2r_symint_outf(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out);
|
| 26 |
+
|
| 27 |
+
} // namespace cuda
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_flash_attention_forward.h
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_flash_attention_forward_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
|
| 27 |
+
return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
|
| 28 |
+
}
|
| 29 |
+
namespace symint {
|
| 30 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
|
| 31 |
+
::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
|
| 32 |
+
return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
| 37 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward_symint(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
|
| 38 |
+
return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
|
| 39 |
+
}
|
| 40 |
+
namespace symint {
|
| 41 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
|
| 42 |
+
::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
|
| 43 |
+
return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API void _foreach_addcmul_Scalar_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out);
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 21 |
+
TORCH_API void foreach_tensor_addcmul_scalar_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 22 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 23 |
+
TORCH_API void foreach_tensor_addcmul_scalar_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 24 |
+
TORCH_API void _foreach_addcmul_ScalarList_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out);
|
| 25 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 26 |
+
TORCH_API void foreach_tensor_addcmul_scalarlist_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 27 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 28 |
+
TORCH_API void foreach_tensor_addcmul_scalarlist_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 29 |
+
TORCH_API void _foreach_addcmul_Tensor_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out);
|
| 30 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 31 |
+
TORCH_API void foreach_tensor_addcmul_tensor_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 32 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 33 |
+
TORCH_API void foreach_tensor_addcmul_tensor_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 34 |
+
} // namespace native
|
| 35 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_erfc_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_erfc(at::TensorList self);
|
| 21 |
+
TORCH_API void _foreach_erfc_(at::TensorList self);
|
| 22 |
+
|
| 23 |
+
} // namespace cuda
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_round.h
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_foreach_round_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_foreach_round(Tensor[] self) -> Tensor[]
|
| 26 |
+
inline ::std::vector<at::Tensor> _foreach_round(at::TensorList self) {
|
| 27 |
+
return at::_ops::_foreach_round::call(self);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_foreach_round_(Tensor(a!)[] self) -> ()
|
| 31 |
+
inline void _foreach_round_(at::TensorList self) {
|
| 32 |
+
return at::_ops::_foreach_round_::call(self);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
|
| 36 |
+
inline void _foreach_round_out(at::TensorList out, at::TensorList self) {
|
| 37 |
+
return at::_ops::_foreach_round_out::call(self, out);
|
| 38 |
+
}
|
| 39 |
+
// aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
|
| 40 |
+
inline void _foreach_round_outf(at::TensorList self, at::TensorList out) {
|
| 41 |
+
return at::_ops::_foreach_round_out::call(self, out);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_sin_native.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API void _foreach_sin_out(at::TensorList self, at::TensorList out);
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_sin_slow(at::TensorList self);
|
| 21 |
+
TORCH_API void foreach_tensor_sin_slow_(at::TensorList self);
|
| 22 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_sin_cuda(at::TensorList self);
|
| 23 |
+
TORCH_API void foreach_tensor_sin_cuda_(at::TensorList self);
|
| 24 |
+
} // namespace native
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_tanh_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_tanh(at::TensorList self);
|
| 21 |
+
TORCH_API void _foreach_tanh_(at::TensorList self);
|
| 22 |
+
|
| 23 |
+
} // namespace cpu
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_functional_assert_scalar_native.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor _functional_assert_scalar(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token);
|
| 20 |
+
} // namespace native
|
| 21 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _linalg_svd {
|
| 18 |
+
using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, bool, bool, c10::optional<c10::string_view>);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_linalg_svd")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)")
|
| 24 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver);
|
| 25 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _linalg_svd_U {
|
| 29 |
+
using schema = ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> (const at::Tensor &, bool, bool, c10::optional<c10::string_view>, at::Tensor &, at::Tensor &, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_linalg_svd")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "U")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)")
|
| 35 |
+
static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> call(const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh);
|
| 36 |
+
static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_from_padded.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_nested_from_padded_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
|
| 26 |
+
inline at::Tensor _nested_from_padded(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
|
| 27 |
+
return at::_ops::_nested_from_padded::call(padded, cpu_nested_shape_example, fuse_transform_0213);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
|
| 31 |
+
inline at::Tensor & _nested_from_padded_out(at::Tensor & out, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
|
| 32 |
+
return at::_ops::_nested_from_padded_out::call(padded, cpu_nested_shape_example, fuse_transform_0213, out);
|
| 33 |
+
}
|
| 34 |
+
// aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
inline at::Tensor & _nested_from_padded_outf(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out) {
|
| 36 |
+
return at::_ops::_nested_from_padded_out::call(padded, cpu_nested_shape_example, fuse_transform_0213, out);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _nnpack_spatial_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1);
|
| 21 |
+
TORCH_API at::Tensor _nnpack_spatial_convolution_symint(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1));
|
| 22 |
+
TORCH_API at::Tensor & _nnpack_spatial_convolution_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1);
|
| 23 |
+
TORCH_API at::Tensor & _nnpack_spatial_convolution_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out);
|
| 24 |
+
TORCH_API at::Tensor & _nnpack_spatial_convolution_symint_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1));
|
| 25 |
+
TORCH_API at::Tensor & _nnpack_spatial_convolution_symint_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
|
| 26 |
+
|
| 27 |
+
} // namespace compositeexplicitautograd
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_prelu_kernel_backward_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::tuple<at::Tensor,at::Tensor> _prelu_kernel_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight);
|
| 21 |
+
|
| 22 |
+
} // namespace cuda
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _transform_bias_rescale_qkv(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
|
| 27 |
+
return at::_ops::_transform_bias_rescale_qkv::call(qkv, qkv_bias, num_heads);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
|
| 31 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
|
| 32 |
+
return at::_ops::_transform_bias_rescale_qkv_out::call(qkv, qkv_bias, num_heads, out0, out1, out2);
|
| 33 |
+
}
|
| 34 |
+
// aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
|
| 35 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_outf(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
|
| 36 |
+
return at::_ops::_transform_bias_rescale_qkv_out::call(qkv, qkv_bias, num_heads, out0, out1, out2);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor acos(const at::Tensor & self);
|
| 21 |
+
TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self);
|
| 22 |
+
TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out);
|
| 23 |
+
TORCH_API at::Tensor & acos_(at::Tensor & self);
|
| 24 |
+
|
| 25 |
+
} // namespace meta
|
| 26 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API alias {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::alias")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "alias(Tensor(a) self) -> Tensor(a)")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/any_meta_dispatch.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false);
|
| 21 |
+
TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false);
|
| 22 |
+
TORCH_API at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out);
|
| 23 |
+
TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
|
| 24 |
+
TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
|
| 25 |
+
TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out);
|
| 26 |
+
TORCH_API at::Tensor any(const at::Tensor & self);
|
| 27 |
+
TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self);
|
| 28 |
+
TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out);
|
| 29 |
+
|
| 30 |
+
} // namespace meta
|
| 31 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::tuple<at::Tensor,at::Tensor> batch_norm_update_stats(const at::Tensor & input, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum);
|
| 21 |
+
|
| 22 |
+
} // namespace cuda
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor ccol_indices(const at::Tensor & self);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeexplicitautograd
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/column_stack_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor column_stack(at::TensorList tensors);
|
| 20 |
+
TORCH_API at::Tensor & column_stack_out(at::TensorList tensors, at::Tensor & out);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/concatenate.h
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/concatenate_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor
|
| 26 |
+
inline at::Tensor concatenate(at::TensorList tensors, int64_t dim=0) {
|
| 27 |
+
return at::_ops::concatenate::call(tensors, dim);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
|
| 31 |
+
inline at::Tensor & concatenate_out(at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
|
| 32 |
+
return at::_ops::concatenate_out::call(tensors, dim, out);
|
| 33 |
+
}
|
| 34 |
+
// aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
inline at::Tensor & concatenate_outf(at::TensorList tensors, int64_t dim, at::Tensor & out) {
|
| 36 |
+
return at::_ops::concatenate_out::call(tensors, dim, out);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor
|
| 40 |
+
inline at::Tensor concatenate(at::TensorList tensors, at::Dimname dim) {
|
| 41 |
+
return at::_ops::concatenate_names::call(tensors, dim);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
|
| 45 |
+
inline at::Tensor & concatenate_out(at::Tensor & out, at::TensorList tensors, at::Dimname dim) {
|
| 46 |
+
return at::_ops::concatenate_names_out::call(tensors, dim, out);
|
| 47 |
+
}
|
| 48 |
+
// aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
|
| 49 |
+
inline at::Tensor & concatenate_outf(at::TensorList tensors, at::Dimname dim, at::Tensor & out) {
|
| 50 |
+
return at::_ops::concatenate_names_out::call(tensors, dim, out);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_convolution_add_relu_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API cudnn_convolution_add_relu {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::cudnn_convolution_add_relu")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API cudnn_convolution_add_relu_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::cudnn_convolution_add_relu")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::tuple<at::Tensor,at::Tensor> cudnn_grid_sampler_backward(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output);
|
| 21 |
+
|
| 22 |
+
} // namespace cuda
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor & diag_embed_out(at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1);
|
| 21 |
+
TORCH_API at::Tensor & diag_embed_outf(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out);
|
| 22 |
+
|
| 23 |
+
} // namespace compositeexplicitautograd
|
| 24 |
+
} // namespace at
|