Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp +354 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py +19 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py +746 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py +276 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py +599 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +153 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py +80 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +406 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py +227 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py +909 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py +1317 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py +694 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py +854 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py +131 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py +1266 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py +212 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py +881 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py +1318 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py +800 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py +2589 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py +688 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py +145 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -137,3 +137,8 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 137 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 138 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 139 |
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 138 |
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 139 |
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 140 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 141 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 142 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 143 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 144 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:092aa1e8b674926d96609f7d70e837e88bb2433dce56bfb3b265696082850bf7
|
| 3 |
+
size 361762
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:088ceca24b4ba43a80ac34a889d96ba95ca4779739aea2e41a34600e2f7fd8ae
|
| 3 |
+
size 103679
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a9dde4c92669d913e3e0a7bf7bbc82b533f3e5492ce09ea3322607c4c9cec549
|
| 3 |
+
size 106873
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:454c0716087aee149fe6ab1aaf1a2f50703e82c70c9edf9bd21bb618627510ad
|
| 3 |
+
size 176338
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d043ca807b2f2cb8a6fa93a9c0fe15f2091627b58bf0206323eb3d52b7c26cc
|
| 3 |
+
size 122305
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
|
| 2 |
+
#include <torch/csrc/inductor/aoti_runtime/interface.h>
|
| 3 |
+
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
|
| 4 |
+
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
|
| 5 |
+
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
|
| 6 |
+
|
| 7 |
+
#include <iostream>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
#include <stdexcept>
|
| 10 |
+
#include <vector>
|
| 11 |
+
|
| 12 |
+
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
|
| 13 |
+
try { \
|
| 14 |
+
__VA_ARGS__ \
|
| 15 |
+
} catch (const std::exception& e) { \
|
| 16 |
+
std::cerr << "Error: " << e.what() << std::endl; \
|
| 17 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 18 |
+
} catch (...) { \
|
| 19 |
+
std::cerr << "Unknown exception occurred." << std::endl; \
|
| 20 |
+
return AOTI_RUNTIME_FAILURE; \
|
| 21 |
+
} \
|
| 22 |
+
return AOTI_RUNTIME_SUCCESS;
|
| 23 |
+
|
| 24 |
+
#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
|
| 25 |
+
do { \
|
| 26 |
+
AOTI_RUNTIME_CHECK( \
|
| 27 |
+
actual_size == expected_size, \
|
| 28 |
+
"expected " + std::string(name) + " vector size to be " + \
|
| 29 |
+
std::to_string(expected_size) + ", but got " + \
|
| 30 |
+
std::to_string(actual_size)); \
|
| 31 |
+
} while (0)
|
| 32 |
+
|
| 33 |
+
// AOTInductor uses at::addmm_out, which doesn't supports
|
| 34 |
+
// arguments that requires gradient. For this reason, we
|
| 35 |
+
// enforce no_grad context for run APIs.
|
| 36 |
+
//
|
| 37 |
+
// A RAII, thread local (!) guard that enables or disables grad mode upon
|
| 38 |
+
// construction, and sets it back to the original value upon destruction.
|
| 39 |
+
struct AOTINoGradGuard {
|
| 40 |
+
AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
|
| 41 |
+
aoti_torch_grad_mode_set_enabled(false);
|
| 42 |
+
}
|
| 43 |
+
~AOTINoGradGuard() {
|
| 44 |
+
aoti_torch_grad_mode_set_enabled(prev_mode);
|
| 45 |
+
}
|
| 46 |
+
bool prev_mode;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
extern "C" {
|
| 50 |
+
|
| 51 |
+
AOTIRuntimeError AOTInductorModelContainerCreate(
|
| 52 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 53 |
+
size_t num_models,
|
| 54 |
+
bool is_cpu,
|
| 55 |
+
const char* cubin_dir) {
|
| 56 |
+
return AOTInductorModelContainerCreateWithDevice(
|
| 57 |
+
container_handle,
|
| 58 |
+
num_models,
|
| 59 |
+
is_cpu ? "cpu" : "cuda",
|
| 60 |
+
cubin_dir);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
|
| 64 |
+
AOTInductorModelContainerHandle* container_handle,
|
| 65 |
+
size_t num_models,
|
| 66 |
+
const char* device_str,
|
| 67 |
+
const char* cubin_dir) {
|
| 68 |
+
if (num_models == 0) {
|
| 69 |
+
std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
|
| 70 |
+
return AOTI_RUNTIME_FAILURE;
|
| 71 |
+
}
|
| 72 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 73 |
+
std::optional<std::string> cubin_dir_opt;
|
| 74 |
+
if (cubin_dir != nullptr) {
|
| 75 |
+
cubin_dir_opt.emplace(cubin_dir);
|
| 76 |
+
}
|
| 77 |
+
auto* container = new torch::aot_inductor::AOTInductorModelContainer(
|
| 78 |
+
num_models, std::string(device_str), cubin_dir_opt);
|
| 79 |
+
*container_handle =
|
| 80 |
+
reinterpret_cast<AOTInductorModelContainerHandle>(container);
|
| 81 |
+
})
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
AOTIRuntimeError AOTInductorModelContainerDelete(
|
| 85 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 86 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 87 |
+
auto* container =
|
| 88 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 89 |
+
container_handle);
|
| 90 |
+
delete container;
|
| 91 |
+
});
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
AOTIRuntimeError AOTInductorModelContainerRun(
|
| 95 |
+
AOTInductorModelContainerHandle container_handle,
|
| 96 |
+
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
|
| 97 |
+
// are stolen; the array itself is borrowed
|
| 98 |
+
size_t num_inputs,
|
| 99 |
+
AtenTensorHandle*
|
| 100 |
+
output_handles, // array for writing output AtenTensorHandle; handles
|
| 101 |
+
// will be stolen by the caller; the array itself is
|
| 102 |
+
// borrowed
|
| 103 |
+
size_t num_outputs,
|
| 104 |
+
AOTInductorStreamHandle stream_handle,
|
| 105 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 106 |
+
auto* container =
|
| 107 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 108 |
+
container_handle);
|
| 109 |
+
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
|
| 110 |
+
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
|
| 111 |
+
|
| 112 |
+
auto stream =
|
| 113 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 114 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 115 |
+
AOTINoGradGuard guard;
|
| 116 |
+
container->run(
|
| 117 |
+
input_handles, output_handles, stream, proxy_executor_handle);
|
| 118 |
+
})
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
|
| 122 |
+
AOTInductorModelContainerHandle container_handle,
|
| 123 |
+
size_t* num_constants) {
|
| 124 |
+
auto* container =
|
| 125 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 126 |
+
container_handle);
|
| 127 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 128 |
+
{ *num_constants = container->num_constants(); })
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantName(
|
| 132 |
+
AOTInductorModelContainerHandle container_handle,
|
| 133 |
+
size_t idx,
|
| 134 |
+
const char** name) {
|
| 135 |
+
auto* container =
|
| 136 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 137 |
+
container_handle);
|
| 138 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 139 |
+
{ *name = container->constant_name(idx); })
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
|
| 143 |
+
AOTInductorModelContainerHandle container_handle,
|
| 144 |
+
size_t idx,
|
| 145 |
+
const char** original_fqn) {
|
| 146 |
+
auto* container =
|
| 147 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 148 |
+
container_handle);
|
| 149 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 150 |
+
{ *original_fqn = container->constant_original_fqn(idx); })
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
|
| 154 |
+
AOTInductorModelContainerHandle container_handle,
|
| 155 |
+
size_t idx,
|
| 156 |
+
bool* from_folded) {
|
| 157 |
+
auto* container =
|
| 158 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
|
| 159 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); })
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
|
| 163 |
+
AOTInductorModelContainerHandle container_handle,
|
| 164 |
+
size_t idx,
|
| 165 |
+
int32_t* dtype) {
|
| 166 |
+
auto* container =
|
| 167 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 168 |
+
container_handle);
|
| 169 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 170 |
+
{ *dtype = container->constant_dtype(idx); })
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
|
| 174 |
+
AOTInductorModelContainerHandle container_handle,
|
| 175 |
+
AOTInductorConstantMapHandle constant_map_handle,
|
| 176 |
+
bool use_inactive,
|
| 177 |
+
bool validate_full_update) {
|
| 178 |
+
auto* container =
|
| 179 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 180 |
+
container_handle);
|
| 181 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 182 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 183 |
+
container->update_constant_buffer(
|
| 184 |
+
*input_map, use_inactive, validate_full_update);
|
| 185 |
+
})
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
|
| 189 |
+
AOTInductorModelContainerHandle container_handle,
|
| 190 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 191 |
+
return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
|
| 192 |
+
constant_map_handle,
|
| 193 |
+
/*use_inactive*/ true,
|
| 194 |
+
/*validate_full_update*/ true);
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
|
| 198 |
+
AOTInductorModelContainerHandle container_handle,
|
| 199 |
+
bool use_inactive,
|
| 200 |
+
AOTInductorStreamHandle stream_handle,
|
| 201 |
+
AOTIProxyExecutorHandle proxy_executor_handle) {
|
| 202 |
+
auto* container =
|
| 203 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 204 |
+
container_handle);
|
| 205 |
+
auto stream =
|
| 206 |
+
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
|
| 207 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 208 |
+
AOTINoGradGuard guard;
|
| 209 |
+
container->run_const_fold(use_inactive, stream, proxy_executor_handle);
|
| 210 |
+
})
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
|
| 214 |
+
AOTInductorModelContainerHandle container_handle) {
|
| 215 |
+
auto* container =
|
| 216 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 217 |
+
container_handle);
|
| 218 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 219 |
+
container->swap_constant_buffer();
|
| 220 |
+
})
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
|
| 224 |
+
AOTInductorModelContainerHandle container_handle,
|
| 225 |
+
size_t* ret_num_inputs) {
|
| 226 |
+
auto* container =
|
| 227 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 228 |
+
container_handle);
|
| 229 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 230 |
+
{ *ret_num_inputs = container->num_inputs(); })
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
AOTIRuntimeError AOTInductorModelContainerGetInputName(
|
| 234 |
+
AOTInductorModelContainerHandle container_handle,
|
| 235 |
+
size_t input_idx,
|
| 236 |
+
const char** ret_input_names) {
|
| 237 |
+
auto* container =
|
| 238 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 239 |
+
container_handle);
|
| 240 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 241 |
+
{ *ret_input_names = container->input_name(input_idx); })
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
|
| 245 |
+
AOTInductorModelContainerHandle container_handle,
|
| 246 |
+
size_t* ret_num_outputs) {
|
| 247 |
+
auto* container =
|
| 248 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 249 |
+
container_handle);
|
| 250 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 251 |
+
{ *ret_num_outputs = container->num_outputs(); })
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
|
| 255 |
+
AOTInductorModelContainerHandle container_handle,
|
| 256 |
+
size_t output_idx,
|
| 257 |
+
const char** ret_output_names) {
|
| 258 |
+
auto* container =
|
| 259 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 260 |
+
container_handle);
|
| 261 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
| 262 |
+
{ *ret_output_names = container->output_name(output_idx); })
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
|
| 266 |
+
AOTInductorModelContainerHandle container_handle,
|
| 267 |
+
const char** in_spec,
|
| 268 |
+
const char** out_spec) {
|
| 269 |
+
auto* container =
|
| 270 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
| 271 |
+
container_handle);
|
| 272 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 273 |
+
*in_spec = container->get_in_spec();
|
| 274 |
+
*out_spec = container->get_out_spec();
|
| 275 |
+
})
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
AOTIRuntimeError AOTInductorModelCreate(
|
| 279 |
+
AOTInductorModelHandle* model_handle,
|
| 280 |
+
AOTInductorConstantMapHandle constant_map_handle){
|
| 281 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 282 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 283 |
+
auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
|
| 284 |
+
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
|
| 285 |
+
|
| 286 |
+
auto model = new torch::aot_inductor::AOTInductorModel(
|
| 287 |
+
constant_map,
|
| 288 |
+
constant_array,
|
| 289 |
+
"cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
|
| 290 |
+
""
|
| 291 |
+
);
|
| 292 |
+
|
| 293 |
+
if (input_map) {
|
| 294 |
+
for (auto const& kv : *input_map) {
|
| 295 |
+
constant_map->emplace(kv.first, kv.second);
|
| 296 |
+
}
|
| 297 |
+
} else {
|
| 298 |
+
model->load_constants();
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
*model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
|
| 302 |
+
})}
|
| 303 |
+
|
| 304 |
+
AOTIRuntimeError AOTInductorModelRun(
|
| 305 |
+
AOTInductorModelHandle model_handle,
|
| 306 |
+
AtenTensorHandle* input_handles,
|
| 307 |
+
AtenTensorHandle* output_handles) {
|
| 308 |
+
auto model =
|
| 309 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 310 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 311 |
+
AOTINoGradGuard guard;
|
| 312 |
+
model->run_impl(
|
| 313 |
+
input_handles,
|
| 314 |
+
output_handles,
|
| 315 |
+
(torch::aot_inductor::DeviceStreamType) nullptr,
|
| 316 |
+
nullptr);
|
| 317 |
+
})
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){
|
| 321 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 322 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
|
| 323 |
+
model_handle);
|
| 324 |
+
delete model;
|
| 325 |
+
})}
|
| 326 |
+
|
| 327 |
+
AOTIRuntimeError AOTInductorModelGetNumOutputs(
|
| 328 |
+
AOTInductorModelHandle model_handle,
|
| 329 |
+
size_t* ret_num_outputs) {
|
| 330 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 331 |
+
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 332 |
+
*ret_num_outputs = model->num_outputs();
|
| 333 |
+
})
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
|
| 337 |
+
AOTInductorModelHandle model_handle,
|
| 338 |
+
AOTInductorConstantMapHandle constant_map_handle) {
|
| 339 |
+
auto model =
|
| 340 |
+
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
|
| 341 |
+
CONVERT_EXCEPTION_TO_ERROR_CODE({
|
| 342 |
+
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
|
| 343 |
+
auto input_map =
|
| 344 |
+
reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
|
| 345 |
+
constant_map_handle);
|
| 346 |
+
|
| 347 |
+
for (auto const& kv : *input_map) {
|
| 348 |
+
constant_map->emplace(kv.first, kv.second);
|
| 349 |
+
}
|
| 350 |
+
model->update_constants_map(std::move(constant_map));
|
| 351 |
+
})
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
} // extern "C"
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from ..common import DeviceOpOverrides, register_device_op_overrides
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class XPUDeviceOpOverrides(DeviceOpOverrides):
|
| 6 |
+
def import_get_raw_stream_as(self, name):
|
| 7 |
+
return f"from torch._C import _xpu_getCurrentRawStream as {name}"
|
| 8 |
+
|
| 9 |
+
def set_device(self, device_idx):
|
| 10 |
+
return f"torch.xpu.set_device({device_idx})"
|
| 11 |
+
|
| 12 |
+
def synchronize(self):
|
| 13 |
+
return "torch.xpu.synchronize()"
|
| 14 |
+
|
| 15 |
+
def device_guard(self, device_idx):
|
| 16 |
+
return f"torch.xpu._DeviceGuard({device_idx})"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
register_device_op_overrides("xpu", XPUDeviceOpOverrides())
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc
ADDED
|
Binary file (30.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py
ADDED
|
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
from collections import deque
|
| 4 |
+
from typing import Dict, List, Set, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils._pytree import tree_map
|
| 8 |
+
|
| 9 |
+
from ..._dynamo.utils import counters
|
| 10 |
+
from ..ir import (
|
| 11 |
+
ComputedBuffer,
|
| 12 |
+
FixedLayout,
|
| 13 |
+
FlexibleLayout,
|
| 14 |
+
InputBuffer,
|
| 15 |
+
StorageBox,
|
| 16 |
+
Subgraph,
|
| 17 |
+
TensorBox,
|
| 18 |
+
)
|
| 19 |
+
from ..lowering import lowerings
|
| 20 |
+
from ..pattern_matcher import (
|
| 21 |
+
Arg,
|
| 22 |
+
CallFunction,
|
| 23 |
+
Match,
|
| 24 |
+
PatternMatcherPass,
|
| 25 |
+
register_graph_pattern,
|
| 26 |
+
)
|
| 27 |
+
from ..select_algorithm import (
|
| 28 |
+
autotune_select_algorithm,
|
| 29 |
+
ExternKernelChoice,
|
| 30 |
+
TritonTemplate,
|
| 31 |
+
TritonTemplateCaller,
|
| 32 |
+
)
|
| 33 |
+
from ..utils import ceildiv
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
B2B_GEMM_PASS = PatternMatcherPass(
|
| 37 |
+
pass_name="b2b_gemm_pass",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def b2b_gemm_grid(M, P, meta):
    """Launch grid for the b2b_gemm templates: one program per (M-tile, P-tile) pair."""
    m_tiles = ceildiv(M, meta["BLOCK_SIZE_M"])
    p_tiles = ceildiv(P, meta["BLOCK_SIZE_P"])
    return (m_tiles * p_tiles, 1, 1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
b2b_gemm_left_template = TritonTemplate(
|
| 46 |
+
name="b2b_gemm_left",
|
| 47 |
+
grid=b2b_gemm_grid,
|
| 48 |
+
debug=False,
|
| 49 |
+
source=r"""
|
| 50 |
+
{{def_kernel("A", "B", "C")}}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# B2B_GEMM_LEFT_TRITON_ENTRANCE
|
| 54 |
+
|
| 55 |
+
# dynamic shapes
|
| 56 |
+
M = {{size("A", 0)}}
|
| 57 |
+
N = {{size("A", 1)}}
|
| 58 |
+
O = {{size("C", 0)}}
|
| 59 |
+
P = {{size("C", 1)}}
|
| 60 |
+
|
| 61 |
+
# dynamic strides
|
| 62 |
+
stride_am = {{stride("A", 0)}}
|
| 63 |
+
stride_an = {{stride("A", 1)}}
|
| 64 |
+
stride_bn = {{stride("B", 0)}}
|
| 65 |
+
stride_bo = {{stride("B", 1)}}
|
| 66 |
+
stride_co = {{stride("C", 0)}}
|
| 67 |
+
stride_cp = {{stride("C", 1)}}
|
| 68 |
+
|
| 69 |
+
# output block counts
|
| 70 |
+
num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
|
| 71 |
+
num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
|
| 72 |
+
|
| 73 |
+
# internal block counts
|
| 74 |
+
num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
|
| 75 |
+
num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
|
| 76 |
+
|
| 77 |
+
# output block ids
|
| 78 |
+
pid = tl.program_id(axis=0)
|
| 79 |
+
m_block_id = pid // num_p_block
|
| 80 |
+
p_block_id = pid % num_p_block
|
| 81 |
+
|
| 82 |
+
# accumulator
|
| 83 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
|
| 84 |
+
|
| 85 |
+
# main loop
|
| 86 |
+
offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
|
| 87 |
+
offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
|
| 88 |
+
# (subgraph(A @ B) @ C)
|
| 89 |
+
offs_o = tl.arange(0, BLOCK_SIZE_O)
|
| 90 |
+
for _ in range(num_o_block):
|
| 91 |
+
c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
|
| 92 |
+
c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
|
| 93 |
+
c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P
|
| 94 |
+
acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32)
|
| 95 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
| 96 |
+
for __ in range(num_n_block):
|
| 97 |
+
a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
| 98 |
+
a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
|
| 99 |
+
a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N
|
| 100 |
+
b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
|
| 101 |
+
b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
|
| 102 |
+
b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O
|
| 103 |
+
acc_ab += tl.dot(a, b, out_dtype=tl.float32)
|
| 104 |
+
offs_n += BLOCK_SIZE_N
|
| 105 |
+
# apply the subgraph
|
| 106 |
+
{{ modification(
|
| 107 |
+
subgraph_number=0,
|
| 108 |
+
output_name="post_subgraph_acc_ab",
|
| 109 |
+
inner_mm="acc_ab"
|
| 110 |
+
) | indent_except_first(2) }}
|
| 111 |
+
acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32)
|
| 112 |
+
offs_o += BLOCK_SIZE_O
|
| 113 |
+
|
| 114 |
+
# type conversion
|
| 115 |
+
acc = acc.to(tl.float16)
|
| 116 |
+
|
| 117 |
+
# store preparation
|
| 118 |
+
idx_m = offs_m[:, None]
|
| 119 |
+
idx_p = offs_p[None, :]
|
| 120 |
+
out_mask = (idx_m < M) & (idx_p < P)
|
| 121 |
+
|
| 122 |
+
{{store_output(("idx_m", "idx_p"), "acc", "out_mask")}}
|
| 123 |
+
""",
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
b2b_gemm_right_template = TritonTemplate(
|
| 128 |
+
name="b2b_gemm_right",
|
| 129 |
+
grid=b2b_gemm_grid,
|
| 130 |
+
debug=False,
|
| 131 |
+
source=r"""
|
| 132 |
+
{{def_kernel("A", "B", "C")}}
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# B2B_GEMM_RIGHT_TRITON_ENTRANCE
|
| 136 |
+
|
| 137 |
+
# dynamic shapes
|
| 138 |
+
M = {{size("A", 0)}}
|
| 139 |
+
N = {{size("A", 1)}}
|
| 140 |
+
O = {{size("C", 0)}}
|
| 141 |
+
P = {{size("C", 1)}}
|
| 142 |
+
|
| 143 |
+
# dynamic strides
|
| 144 |
+
stride_am = {{stride("A", 0)}}
|
| 145 |
+
stride_an = {{stride("A", 1)}}
|
| 146 |
+
stride_bn = {{stride("B", 0)}}
|
| 147 |
+
stride_bo = {{stride("B", 1)}}
|
| 148 |
+
stride_co = {{stride("C", 0)}}
|
| 149 |
+
stride_cp = {{stride("C", 1)}}
|
| 150 |
+
|
| 151 |
+
# output block counts
|
| 152 |
+
num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
|
| 153 |
+
num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
|
| 154 |
+
|
| 155 |
+
# internal block counts
|
| 156 |
+
num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
|
| 157 |
+
num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
|
| 158 |
+
|
| 159 |
+
# output block ids
|
| 160 |
+
pid = tl.program_id(axis=0)
|
| 161 |
+
m_block_id = pid // num_p_block
|
| 162 |
+
p_block_id = pid % num_p_block
|
| 163 |
+
|
| 164 |
+
# accumulator
|
| 165 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
|
| 166 |
+
|
| 167 |
+
# main loop (two cases)
|
| 168 |
+
offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
|
| 169 |
+
offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
|
| 170 |
+
# (A @ subgraph(B @ C))
|
| 171 |
+
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
| 172 |
+
for _ in range(num_n_block):
|
| 173 |
+
a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
| 174 |
+
a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
|
| 175 |
+
a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N
|
| 176 |
+
acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32)
|
| 177 |
+
offs_o = tl.arange(0, BLOCK_SIZE_O)
|
| 178 |
+
for __ in range(num_o_block):
|
| 179 |
+
b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
|
| 180 |
+
b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
|
| 181 |
+
b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O
|
| 182 |
+
c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
|
| 183 |
+
c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
|
| 184 |
+
c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P
|
| 185 |
+
acc_bc += tl.dot(b, c, out_dtype=tl.float32)
|
| 186 |
+
offs_o += BLOCK_SIZE_O
|
| 187 |
+
# apply the subgraph
|
| 188 |
+
{{ modification(
|
| 189 |
+
subgraph_number=0,
|
| 190 |
+
output_name="post_subgraph_acc_bc",
|
| 191 |
+
inner_mm="acc_bc"
|
| 192 |
+
) | indent_except_first(2) }}
|
| 193 |
+
acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32)
|
| 194 |
+
offs_n += BLOCK_SIZE_N
|
| 195 |
+
|
| 196 |
+
# type conversion
|
| 197 |
+
acc = acc.to(tl.float16)
|
| 198 |
+
|
| 199 |
+
# store preparation
|
| 200 |
+
idx_m = offs_m[:, None]
|
| 201 |
+
idx_p = offs_p[None, :]
|
| 202 |
+
out_mask = (idx_m < M) & (idx_p < P)
|
| 203 |
+
|
| 204 |
+
{{store_output(("idx_m", "idx_p"), "acc", "out_mask")}}
|
| 205 |
+
""",
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Note: load_ratio_left and load_ratio_right are only calculating numbers
|
| 210 |
+
# in the trivial subgraph case; i.e. (A @ (B @ C)) or ((A @ B) @ C)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def load_ratio_left(
    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
) -> float:
    """
    Compute the ratio of estimated numbers of loads in baseline and b2bgemm
    for the left-associated case ((A @ B) @ C).

    M, N, O, P are matrix sizes; m, n, o, p are block sizes.
    | | baseline (lower bound) | b2bgemm
    | load  | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o))
    | store | M * O + M * P | M * P
    b2bgemm is always better on stores, but for loads we need to find out
    beneficial cases using this function.
    """
    baseline_loads = M * N + N * O + M * O + O * P
    m_tiles = ceildiv(M, m)
    p_tiles = ceildiv(P, p)
    o_tiles = ceildiv(O, o)
    n_tiles = ceildiv(N, n)
    fused_loads = m_tiles * p_tiles * o_tiles * (o * p + n_tiles * (m * n + n * o))
    return baseline_loads / fused_loads
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def load_ratio_right(
    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
) -> float:
    """
    Compute the ratio of estimated numbers of loads in baseline and b2bgemm
    for the right-associated case (A @ (B @ C)).

    M, N, O, P are matrix sizes; m, n, o, p are block sizes.
    | | baseline (lower bound) | b2bgemm
    | load  | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p))
    | store | N * P + M * P | M * P
    b2bgemm is always better on stores, but for loads we need to find out
    beneficial cases using this function.
    """
    baseline_loads = N * O + O * P + M * N + N * P
    m_tiles = ceildiv(M, m)
    p_tiles = ceildiv(P, p)
    n_tiles = ceildiv(N, n)
    o_tiles = ceildiv(O, o)
    fused_loads = m_tiles * p_tiles * n_tiles * (m * n + o_tiles * (n * o + o * p))
    return baseline_loads / fused_loads
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# the block sizes are limited by hardware (the shared memory)
|
| 258 |
+
# intuitively, the optimization works when the intermediate matrix is large
|
| 259 |
+
# and we assign large block sizes to large dimensions
|
| 260 |
+
b2b_gemm_configs = [
|
| 261 |
+
{
|
| 262 |
+
"BLOCK_SIZE_M": 128,
|
| 263 |
+
"BLOCK_SIZE_N": 16,
|
| 264 |
+
"BLOCK_SIZE_O": 16,
|
| 265 |
+
"BLOCK_SIZE_P": 16,
|
| 266 |
+
"num_stages": 4,
|
| 267 |
+
"num_warps": 8,
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"BLOCK_SIZE_M": 128,
|
| 271 |
+
"BLOCK_SIZE_N": 32,
|
| 272 |
+
"BLOCK_SIZE_O": 32,
|
| 273 |
+
"BLOCK_SIZE_P": 32,
|
| 274 |
+
"num_stages": 2,
|
| 275 |
+
"num_warps": 4,
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"BLOCK_SIZE_M": 128,
|
| 279 |
+
"BLOCK_SIZE_N": 64,
|
| 280 |
+
"BLOCK_SIZE_O": 64,
|
| 281 |
+
"BLOCK_SIZE_P": 64,
|
| 282 |
+
"num_stages": 2,
|
| 283 |
+
"num_warps": 4,
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"BLOCK_SIZE_M": 128,
|
| 287 |
+
"BLOCK_SIZE_N": 16,
|
| 288 |
+
"BLOCK_SIZE_O": 128,
|
| 289 |
+
"BLOCK_SIZE_P": 16,
|
| 290 |
+
"num_stages": 4,
|
| 291 |
+
"num_warps": 8,
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"BLOCK_SIZE_M": 128,
|
| 295 |
+
"BLOCK_SIZE_N": 32,
|
| 296 |
+
"BLOCK_SIZE_O": 128,
|
| 297 |
+
"BLOCK_SIZE_P": 32,
|
| 298 |
+
"num_stages": 2,
|
| 299 |
+
"num_warps": 4,
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"BLOCK_SIZE_M": 128,
|
| 303 |
+
"BLOCK_SIZE_N": 64,
|
| 304 |
+
"BLOCK_SIZE_O": 128,
|
| 305 |
+
"BLOCK_SIZE_P": 64,
|
| 306 |
+
"num_stages": 2,
|
| 307 |
+
"num_warps": 4,
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"BLOCK_SIZE_M": 16,
|
| 311 |
+
"BLOCK_SIZE_N": 16,
|
| 312 |
+
"BLOCK_SIZE_O": 16,
|
| 313 |
+
"BLOCK_SIZE_P": 128,
|
| 314 |
+
"num_stages": 4,
|
| 315 |
+
"num_warps": 8,
|
| 316 |
+
},
|
| 317 |
+
{
|
| 318 |
+
"BLOCK_SIZE_M": 32,
|
| 319 |
+
"BLOCK_SIZE_N": 32,
|
| 320 |
+
"BLOCK_SIZE_O": 32,
|
| 321 |
+
"BLOCK_SIZE_P": 128,
|
| 322 |
+
"num_stages": 2,
|
| 323 |
+
"num_warps": 4,
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"BLOCK_SIZE_M": 64,
|
| 327 |
+
"BLOCK_SIZE_N": 64,
|
| 328 |
+
"BLOCK_SIZE_O": 64,
|
| 329 |
+
"BLOCK_SIZE_P": 128,
|
| 330 |
+
"num_stages": 2,
|
| 331 |
+
"num_warps": 4,
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"BLOCK_SIZE_M": 16,
|
| 335 |
+
"BLOCK_SIZE_N": 128,
|
| 336 |
+
"BLOCK_SIZE_O": 16,
|
| 337 |
+
"BLOCK_SIZE_P": 128,
|
| 338 |
+
"num_stages": 4,
|
| 339 |
+
"num_warps": 8,
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"BLOCK_SIZE_M": 32,
|
| 343 |
+
"BLOCK_SIZE_N": 128,
|
| 344 |
+
"BLOCK_SIZE_O": 32,
|
| 345 |
+
"BLOCK_SIZE_P": 128,
|
| 346 |
+
"num_stages": 2,
|
| 347 |
+
"num_warps": 4,
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"BLOCK_SIZE_M": 64,
|
| 351 |
+
"BLOCK_SIZE_N": 128,
|
| 352 |
+
"BLOCK_SIZE_O": 64,
|
| 353 |
+
"BLOCK_SIZE_P": 128,
|
| 354 |
+
"num_stages": 2,
|
| 355 |
+
"num_warps": 4,
|
| 356 |
+
},
|
| 357 |
+
]
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def is_b2b_gemm_good_on(
    is_left_assoc: bool,
    A_node: torch.fx.Node,
    B_node: torch.fx.Node,
    C_node: torch.fx.Node,
) -> bool:
    """
    Check whether the matrix sizes make a fused b2b_gemm worthwhile.

    Dispatches to B2B-GEMM only when the geometric mean of the top-3
    estimated load ratios (baseline loads / fused loads) exceeds 1.
    """
    nodes = (A_node, B_node, C_node)
    # all three inputs need fake-tensor metadata so we can read their shapes
    if any("val" not in node.meta for node in nodes):
        return False
    # torch._subclasses.fake_tensor.FakeTensor
    A, B, C = (node.meta["val"] for node in nodes)
    if not (A.is_cuda and B.is_cuda and C.is_cuda):
        return False
    if not (len(A.shape) == 2 and len(B.shape) == 2 and len(C.shape) == 2):
        return False
    # shapes must chain as (M, N) @ (N, O) @ (O, P)
    if A.shape[1] != B.shape[0] or B.shape[1] != C.shape[0]:
        return False
    # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1
    M, N = A.shape
    O, P = C.shape
    estimator = load_ratio_left if is_left_assoc else load_ratio_right
    ratios = [
        estimator(
            M,
            N,
            O,
            P,
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_O"],
            config["BLOCK_SIZE_P"],
        )
        for config in b2b_gemm_configs
    ]
    ratios.sort(reverse=True)
    # geometric mean over the top 3 choices
    product = 1.0
    for ratio in ratios[:3]:
        product *= ratio
    geo_mean = product ** (1 / 3)
    # even if the mean is close to 1, the number of stores is always better
    return geo_mean > 1
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def unoptimized_b2b_gemm(
    is_left_assoc: bool,
    subgraph: Subgraph,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """
    Eager fallback used when the fused b2b_gemm kernel is not beneficial.

    Computes either (subgraph(A @ B) @ C) or (A @ subgraph(B @ C)),
    writing the result into ``out`` and returning it.
    """
    if is_left_assoc:
        intermediate = subgraph.graph_module(torch.mm(A, B))
        torch.mm(intermediate, C, out=out)
    else:
        intermediate = subgraph.graph_module(torch.mm(B, C))
        torch.mm(A, intermediate, out=out)
    return out
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """
    This function is adapted from ../kernel/flex_attention.py.
    The goal is to take in the required args and produce the subgraph buffer.
    The subgraph buffer is a ComputedBuffer that will be inlined into the
    triton template.

    Args:
        args: The TensorBoxes substituted for the subgraph's placeholders,
            in placeholder order
        subgraph: The Subgraph ir for which to produce the output node

    Returns:
        The ComputedBuffer (or pytree of buffers / None mirroring the output
        structure) for the subgraph's output.

    Raises:
        ValueError: if the subgraph has no output node.
    """
    env = {}
    placeholder_idx = 0
    for node in subgraph.graph_module.graph.nodes:
        if node.op == "placeholder":
            env[node] = args[placeholder_idx]
            placeholder_idx += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args.
            # NOTE: use fresh local names here; the previous version rebound
            # the ``args`` parameter itself, which would have broken
            # placeholder lookup if a placeholder ever appeared after a
            # call_function node in the graph.
            lowered_args, lowered_kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*lowered_args, **lowered_kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                # Wrap the subgraph's output TensorBox in a ComputedBuffer
                # with a FlexibleLayout so the template can inline it.
                if output is None:
                    return None
                output_node = output
                output_buffer = env[output_node]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ",
                    type(output_buffer),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] should be a single element representing the output of the subgraph
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("B2B-GEMM was passed a subgraph with no output node!")
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """
    Create a placeholder input buffer (scalar-shaped, fixed layout) used to
    produce the subgraph output.
    """
    layout = FixedLayout(device, dtype, [], [])
    return TensorBox.create(InputBuffer(name, layout))
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def tuned_b2b_gemm(
    is_left_assoc: bool,
    subgraph: Subgraph,
    A: torch._inductor.ir.TensorBox,
    B: torch._inductor.ir.TensorBox,
    C: torch._inductor.ir.TensorBox,
    *,
    layout=None,
) -> torch._inductor.ir.TensorBox:
    """
    Autotune the fused B2B-GEMM over the template configs, including the
    unoptimized eager fallback to guard against performance regressions.

    Args:
        is_left_assoc: True for (subgraph(A @ B) @ C), False for (A @ subgraph(B @ C))
        subgraph: the pointwise epilogue applied to the inner mm
        A, B, C: input TensorBoxes
        layout: optional output layout; defaults to a FixedLayout of shape
            (A.shape[0], C.shape[1]) with A's device and dtype
    """
    # call .realize() to get rid of Pointwise
    A.realize()
    B.realize()
    C.realize()
    # Honor a caller-provided layout; previously the keyword was accepted but
    # unconditionally overwritten.
    if layout is None:
        layout = FixedLayout(A.get_device(), A.get_dtype(), [A.shape[0], C.shape[1]])
    subgraph_buffer = build_subgraph_buffer(
        [create_placeholder("inner_mm", A.get_dtype(), A.get_device())],
        subgraph,
    )
    template = b2b_gemm_left_template if is_left_assoc else b2b_gemm_right_template
    choices: List[TritonTemplateCaller] = []
    for config in b2b_gemm_configs:
        template.maybe_append_choice(
            choices,
            input_nodes=(A, B, C),
            layout=layout,
            subgraphs=[subgraph_buffer],
            **config,
        )
    # add the unoptimized choice to mitigate performance degradation
    choices.append(
        unoptimized_choice.bind(
            (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph
        )
    )
    # autotune
    return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# match the inner mm of a potential b2b_gemm
|
| 560 |
+
# match the inner mm of a potential b2b_gemm
@register_graph_pattern(
    CallFunction(torch.ops.aten.mm, Arg(), Arg()),
    pass_dict=B2B_GEMM_PASS,
)
def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None:
    """
    Rewrite (A @ f(B @ C)) or (f(A @ B) @ C), where f is a chain of pointwise
    ops, into a single call to tuned_b2b_gemm.

    The matched mm is treated as the *inner* mm; the surrounding pointwise
    chain is extracted into a Subgraph so it can be inlined into the fused
    Triton template, and the matched nodes are erased from the original graph.

    NOTE(review): relies on FX graphs being topologically ordered (users appear
    after producers) — confirm when modifying the traversal logic.
    """
    # match.args: list[torch.fx.Node]

    def is_pointwise_node(node: torch.fx.Node) -> bool:
        # a call_function whose OpOverload target carries the pointwise tag
        return (
            node.op == "call_function"
            and isinstance(node.target, torch._ops.OpOverload)
            and (torch.Tag.pointwise in node.target.tags)
        )

    def is_mm(node: torch.fx.Node) -> bool:
        return node.target == torch.ops.aten.mm.default

    # the inner MM
    inner_mm = match.nodes[-1]

    # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it
    # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _).
    # Only the *first* user is followed here; full reachability is verified later.
    outer_mm = None
    node = inner_mm
    while len(node.users) > 0:
        node = next(iter(node.users))
        if is_mm(node):
            outer_mm = node
            break
        elif is_pointwise_node(node):
            continue
        else:
            break
    if not outer_mm:
        return

    # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C))
    # we call it the "f_node"
    # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm
    f_node = inner_mm
    while next(iter(f_node.users)) is not outer_mm:
        f_node = next(iter(f_node.users))

    def all_reach_via_pointwise_with_no_other_inputs(
        src: torch.fx.Node,
        dst: torch.fx.Node,
    ) -> Tuple[bool, Set[torch.fx.Node]]:
        """
        check whether every user path from src reaches dst via pointwise nodes,
        with no other input nodes for the intermediates and dst;
        return
        (1) the Boolean value
        (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True)
        """
        visited: Set[torch.fx.Node] = set()
        # counts down each user's inputs; a node reached via *all* of its
        # inputs ends at exactly 0, i.e. it has no inputs from outside the chain
        input_counter: Dict[torch.fx.Node, int] = {}

        all_reachable = True
        queue = deque([src])
        while queue:
            node = queue.popleft()
            if node not in visited:
                if node is dst:
                    visited.add(node)
                elif (node is src) or is_pointwise_node(node):
                    for user in node.users.keys():
                        # for nodes other than dst, bookkeep their users' input counts
                        if user not in input_counter:
                            input_counter[user] = len(user.all_input_nodes)
                        input_counter[user] -= 1
                        # continue BFS
                        queue.append(user)
                    visited.add(node)
                else:
                    all_reachable = False
                    break

        return (
            all_reachable and all(count == 0 for count in input_counter.values()),
            visited,
        )

    # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes
    ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs(
        inner_mm, f_node
    )
    if not ok:
        return

    # check inner_mm's inputs and f_node's outputs
    if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1):
        return

    # at this point, the nodes between inner_mm and f_node (both included)
    # are all used internally inside (A @ subgraph(B @ C))
    # i.e. they neither have other users nor have other inputs

    # original graph and module
    graph, module = inner_mm.graph, inner_mm.graph.owning_module

    # construct the new (sub)graph
    subgraph_node_list: List[torch.fx.Node] = []  # ordered list of nodes used for node removal later
    new_graph: torch.fx.Graph = torch.fx.Graph()
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    new_input_anchor: torch.fx.Node  # inner_mm, to be changed to an input node
    new_output_anchor: torch.fx.Node  # f_node, to be used to construct an output node
    new_input_node: torch.fx.Node
    new_output_node: torch.fx.Node
    for node in graph.nodes:  # preserve the order of nodes
        if node in subgraph_node_set:
            subgraph_node_list.append(node)
            new_node = new_graph.node_copy(
                node, lambda x: node_remapping[x] if x in node_remapping else x
            )
            node_remapping[node] = new_node
            if node is inner_mm:
                new_input_anchor = new_node
            if node is f_node:
                new_output_anchor = new_node
    if new_input_anchor is not new_output_anchor:  # subgraph is non-trivial
        # update the input node: the copied inner_mm becomes a placeholder,
        # since the fused template computes the inner mm itself
        with new_graph.inserting_before(new_input_anchor):
            new_input_node = new_graph.placeholder(name="subgraph_input")
            new_input_node.meta.update(new_input_anchor.meta)
            new_input_anchor.replace_all_uses_with(new_input_node)
        new_graph.erase_node(new_input_anchor)
        # add the output node
        new_output_node = new_graph.output(new_output_anchor)
        new_output_node.meta.update(new_output_anchor.meta)
    else:  # subgraph is trivial, e.g. (A @ (B @ C))
        # update the input node
        with new_graph.inserting_before(new_input_anchor):
            new_input_node = new_graph.placeholder(name="subgraph_input")
            new_input_node.meta.update(new_input_anchor.meta)
            new_input_anchor.replace_all_uses_with(new_input_node)
        new_graph.erase_node(new_input_anchor)
        # update the output node (don't use new_output_anchor since it has been erased)
        new_output_node = new_graph.output(new_input_node)
        new_output_node.meta.update(new_input_node.meta)
    new_graph.lint()

    # construct the subgraph
    subgraph = Subgraph(
        name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph)
    )

    # two cases
    # (1) (subgraph(A @ B) @ C), called "left_assoc"
    # (2) (A @ subgraph(B @ C)), called "right_assoc"
    is_left_assoc = outer_mm.args[0] is f_node

    # find the nodes A, B, C and check the sizes
    A: torch.fx.Node
    B: torch.fx.Node
    C: torch.fx.Node
    if is_left_assoc:
        A = inner_mm.args[0]  # type: ignore[assignment]
        B = inner_mm.args[1]  # type: ignore[assignment]
        C = outer_mm.args[1]  # type: ignore[assignment]
    else:
        A = outer_mm.args[0]  # type: ignore[assignment]
        B = inner_mm.args[0]  # type: ignore[assignment]
        C = inner_mm.args[1]  # type: ignore[assignment]
    if not is_b2b_gemm_good_on(is_left_assoc, A, B, C):
        return

    # finally update the original graph
    counters["inductor"]["b2b_gemm"] += 1
    graph = match.graph
    with graph.inserting_before(outer_mm):
        # the lowering function is attached directly to the call_function node;
        # Inductor recognizes it via the _inductor_lowering_function flag
        function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph)
        function.__name__ = tuned_b2b_gemm.__name__  # type: ignore[attr-defined]
        function._inductor_lowering_function = True  # type: ignore[attr-defined]
        replacement: torch.fx.Node = graph.call_function(
            function,
            (A, B, C),
            match.kwargs,
        )
        replacement.meta.update(outer_mm.meta)
        outer_mm.replace_all_uses_with(replacement)
    # erase unnecessary nodes
    graph.erase_node(outer_mm)
    for node in reversed(subgraph_node_list):
        graph.erase_node(node)
    graph.lint()
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ..._dynamo.utils import counters
|
| 8 |
+
from ..pattern_matcher import Arg, CallFunction, KeywordArg
|
| 9 |
+
from .freezing_patterns import register_binary_folding_pattern
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def mark_mixed_dtype_conv(conv):
    """
    Tag a half/bfloat16 convolution whose single-use chain of binary ops runs
    in float32 and ends in a dtype conversion, by recording the original dtype
    in conv.meta["_allow_conv_mixed_dtype_folding"].  Returns without marking
    when any precondition fails.
    """
    conv_dtype = conv.meta["val"].dtype
    if conv_dtype not in (torch.float16, torch.bfloat16):
        return

    if len(conv.users) != 1:
        return

    user = next(iter(conv.users.keys()))
    if not isinstance(user.meta["val"], torch.Tensor):
        return

    if user.meta["val"].dtype != torch.float32:
        return

    # walk down the chain of foldable binary ops; every link must be single-use
    while user.target in _binary_ops:
        if len(user.users) != 1:
            return

        user = next(iter(user.users.keys()))

    # the chain must terminate in a dtype conversion
    if user.target != prims.convert_element_type.default:
        return

    conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def mark_mixed_dtype_allowed_convs(gm):
    """
    Mark convolutions which we will binary fold even with mixed precision constants.
    We constant fold in the higher precision for better accuracy and then recover
    the original precision after.
    """
    conv_nodes = gm.graph.find_nodes(
        op="call_function", target=aten.convolution.default
    )
    for conv_node in conv_nodes:
        mark_mixed_dtype_conv(conv_node)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def recover_original_precision_folded_convs(gm):
    """
    After binary folding conv weights and biases to a higher dtype, recover the original precision they were in.
    """
    graph = gm.graph
    for node in graph.find_nodes(op="call_function", target=aten.convolution.default):
        # Only convs tagged by mark_mixed_dtype_conv carry this meta entry.
        orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None)
        if orig_dtype is None:
            continue

        with graph.inserting_before(node):
            # args[1] is the weight, args[2] is the (optional) bias.
            for idx in [1, 2]:
                old_input = node.args[idx]
                if old_input is None:
                    continue

                # Cast the folded constant back down to the conv's original
                # low-precision dtype.
                new_input = graph.create_node(
                    "call_function",
                    prims.convert_element_type.default,
                    (old_input, orig_dtype),
                )
                node.replace_input_with(old_input, new_input)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@functools.lru_cache(None)
def binary_folding_init():
    """Register the conv + pointwise-binary constant-folding patterns.

    Memoized via ``lru_cache`` so registration happens exactly once per
    process.  For every (computation op, binary op) pair a pattern is
    registered that rewrites ``binary_op(conv(...), other)`` into a single
    convolution whose weight and/or bias have ``other`` folded in, provided
    ``other`` is a compatible constant (see ``_is_foldable_pattern``).
    """
    _conv_args = [Arg() for _ in range(9)]
    _computation_ops = [aten.convolution.default]
    _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]

    """
    In order to fuse add/sub/mul/div with conv, the dimensions of its
    constant tensor must satisfy the following:
      - with resizing, broadcast to w/ weight/bias tensor shape
      - broadcast to the conv output shape
    It needs to have a shape that can resize to weight/bias
    tensor shape because we need to run the op with the conv
    weights/bias without changing their sizes.
    It needs to broadcast to the conv output shape so that we don't
    accidentally change the shape of op output by pre-fusing it
    compared to eager.
    The only dimension value shared by weight/bias/conv output
    is they all contain a dim with value = channels-out. In the
    conv output tensor, this is in the second dimension,
    so the pointwise op tensor may have a second dimension of
    value == channels-out, but all the other dimensions have to be 1
    """

    def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
        # True when other_tensor only varies along the channels-out dim
        # (everything else must be 1), so folding it into weight/bias cannot
        # change the conv output shape.
        # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
        weight_shape = weight_tensor.shape
        other_shape = other_tensor.shape
        if len(weight_shape) < len(other_shape):
            return False
        if len(weight_shape) == len(other_shape) + 1:
            # weight shape is [o, i, *], other_shape is [o, 1...].
            for i in reversed(range(len(other_shape))):
                if i == 0 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        else:
            # weight shape is [o, i, *], other_shape is [1, i, *]
            for i in reversed(range(len(other_shape))):
                if i == 1 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        return True

    def _check_conv_and_broadcast_op(conv_node, other):
        # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
        # conv.weight
        if conv_node.args[1].op != "get_attr":
            return False
        # conv.bias is args[2].  (Fixed: the original re-checked args[1], the
        # weight, here, so a non-constant bias was never rejected.)
        if conv_node.args[2] is not None and conv_node.args[2].op != "get_attr":
            return False
        if (
            not isinstance(other, int)
            and not isinstance(other, float)
            and other.op != "get_attr"
        ):
            return False

        # The weight must have a single user, otherwise folding would change
        # the value seen by the other users.
        if not len(conv_node.args[1].users) == 1:
            return False

        weight_meta_value = conv_node.args[1].meta.get("val")
        if weight_meta_value is None:
            return False
        # Avoid fusing op that causes type promotion
        # restricting to float avoids int/float difficulties with scalar overload
        if not weight_meta_value.is_floating_point():
            return False
        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
            other_meta_value = other.meta.get("val")
            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
                return False
            if (
                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
                != weight_meta_value.dtype
            ):
                # Mixed dtypes are only allowed when mark_mixed_dtype_conv
                # has explicitly whitelisted this conv.
                if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
                    return False

                if (
                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
                ):
                    return False

            if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
                return False
        else:
            # TODO: support scalar case
            return False

        return True

    def _is_foldable_pattern(match):
        # extra_check callback: decide whether a matched
        # ``binary_op(conv, other)`` / ``binary_op(other, conv)`` is safe to fold.
        binary_node = match.output_node()
        computation_node = binary_node.args[0]
        other = binary_node.args[1]
        if binary_node.args[0].target not in _computation_ops:
            computation_node = binary_node.args[1]
            other = binary_node.args[0]
        # Check the node actually selected as the computation.  (Fixed: the
        # original re-tested ``binary_node.args[0].target`` here, which
        # rejected every match where the conv is the second argument.)
        if computation_node.target == aten.convolution.default:
            return _check_conv_and_broadcast_op(computation_node, other)

        return False

    def resize_scalar_or_tensor_to_shape(graph, other, shape):
        # Reshape (or expand, for single-element tensors) ``other`` so it can
        # be combined elementwise with the conv weight/bias.
        # TODO: support scalar case
        if other.meta.get("val").numel() == 1:
            # expand errors if the shape input has less # dims than the tensor input
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, (1,)),
            )
            res = graph.create_node(
                "call_function",
                aten.expand.default,
                (res, shape),
            )
        else:
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, shape),
            )
        return res

    def _create_new_conv_node(graph, conv_node, binary_node, other):
        # Build a replacement convolution with ``other`` folded into the bias
        # (add/sub) or into both weight and bias (mul/div).
        assert conv_node.target == aten.convolution.default
        conv_args = list(conv_node.args)
        weight_meta_value = conv_node.args[1].meta.get("val")
        bias = conv_args[2]
        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
            # add/sub only touch the bias: conv(x, w, b) +/- c == conv(x, w, b +/- c).
            other_reshape = resize_scalar_or_tensor_to_shape(
                graph, other, (weight_meta_value.size(0),)
            )
            new_bias = graph.create_node(
                "call_function",
                binary_node.target,
                (0 if bias is None else bias, other_reshape),
            )
            conv_args[2] = new_bias
        else:
            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
            # mul/div scale the weight along channels-out, and the bias if present.
            weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
            weight_broadcast_shape[0] = weight_meta_value.size(0)
            other_reshape1 = resize_scalar_or_tensor_to_shape(
                graph, other, tuple(weight_broadcast_shape)
            )
            new_weight = graph.create_node(
                "call_function", binary_node.target, (conv_args[1], other_reshape1)
            )
            new_weight.meta.update(conv_args[1].meta)
            conv_args[1] = new_weight
            if bias is not None:
                other_reshape = resize_scalar_or_tensor_to_shape(
                    graph, other, (weight_meta_value.size(0),)
                )
                new_bias = graph.create_node(
                    "call_function", binary_node.target, (bias, other_reshape)
                )
                new_bias.meta.update(bias.meta)
                conv_args[2] = new_bias
        return graph.create_node("call_function", conv_node.target, tuple(conv_args))

    for _computation_call, binary_op in itertools.product(
        _computation_calls, _binary_ops
    ):

        @register_binary_folding_pattern(
            CallFunction(binary_op, _computation_call, KeywordArg("other")),
            extra_check=_is_foldable_pattern,
        )
        def folded_op(match, *args, **kwargs):
            # Replace ``binary_op(conv(...), other)`` with a single conv whose
            # constants already incorporate ``other``.
            counters["inductor"]["binary_folding"] += 1
            other = kwargs.get("other")
            binary_node = match.output_node()
            computation_node = (
                binary_node.args[0]
                if binary_node.args[0].target in _computation_ops
                else binary_node.args[1]
            )
            graph = match.graph
            with graph.inserting_before(binary_node):
                # TODO: support linear?
                assert computation_node.target == aten.convolution.default
                new_computation_node = _create_new_conv_node(
                    graph, computation_node, binary_node, other
                )
                binary_node.replace_all_uses_with(new_computation_node)
                new_computation_node.meta.update(computation_node.meta)
                graph.erase_node(binary_node)
                graph.erase_node(computation_node)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Owner(s): ["oncall: distributed"]
|
| 2 |
+
import collections
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import operator
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Any,
|
| 11 |
+
Callable,
|
| 12 |
+
cast,
|
| 13 |
+
Dict,
|
| 14 |
+
Generator,
|
| 15 |
+
List,
|
| 16 |
+
Optional,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
Union,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.fx as fx
|
| 24 |
+
from torch._dynamo.utils import counters
|
| 25 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 26 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 27 |
+
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
| 28 |
+
|
| 29 |
+
from .. import config
|
| 30 |
+
from ..fx_utils import get_fake_args_kwargs
|
| 31 |
+
from ..virtualized import V
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Shorthand for the ATen operator namespace.
aten = torch.ops.aten
# Dedicated logger for the communication-fusion passes.
logger: logging.Logger = logging.getLogger("comm_fusion")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def move_block_after(block: List[fx.Node], target_node: fx.Node) -> None:
    """Relocate ``block`` (kept in order) to directly after ``target_node``."""
    anchor = target_node
    for moved in block:
        # fx.Node.append moves ``moved`` to directly after ``anchor``; by
        # advancing the anchor each step the block keeps its given order.
        anchor.append(moved)
        anchor = moved
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def move_block_before(block: List[fx.Node], target_node: fx.Node) -> None:
    # Relocate the nodes in ``block`` to before ``target_node``.
    # NOTE(review): the anchor is updated to the node just inserted, and
    # ``prepend`` inserts *before* the anchor — so a multi-node block ends up
    # in reversed order relative to ``block``.  Confirm this is intended (or
    # that callers only ever pass single-node blocks).
    for node in block:
        target_node.prepend(node)
        target_node = node
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def call_function(
    graph: fx.Graph,
    target: Union[str, Callable[..., Any]],
    args: Optional[Tuple[fx.node.Argument, ...]] = None,
    kwargs: Optional[Dict[str, fx.node.Argument]] = None,
) -> fx.Node:
    """Insert a call_function node and eagerly populate its fake-tensor meta.

    Unlike a bare ``graph.call_function``, this also executes ``target`` under
    the fake mode so ``node.meta["val"]`` and ``node.meta["tensor_meta"]`` are
    filled in — later fusion steps read both fields.

    Raises:
        RuntimeError: if ``target`` is a string (must be a callable op).
    """
    # We accept target as a str to avoid typing error as the type of
    # a node.target is Union[str, Callable[..., Any]].
    # This also allows us to avoid writing check for every call.
    if isinstance(target, str):
        raise RuntimeError(f"Call function should not get a str target {target=}")
    node = graph.call_function(target, args, kwargs)
    # Resolve the node's args/kwargs to their fake values so target can be run.
    _, args, kwargs = get_fake_args_kwargs(node)
    with V.fake_mode:
        node.meta["val"] = target(*args, **kwargs)
        # node.meta["val"] may be a container. So we use tree_map here
        # to recursively extract the tensor metadata.
        node.meta["tensor_meta"] = tree_map(
            _extract_tensor_metadata, (node.meta["val"],)
        )[0]
    return node
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass(unsafe_hash=True)
class CommBlock:
    # Groups one collective (comm) op together with its surrounding nodes.
    # Output shape(s) of the collective; a list of shapes for coalesced ops.
    shape: Union[torch.Size, List[torch.Size]]
    # Every fx node that belongs to this communication block.
    node_list: List[fx.Node]
    # fx.Node inputs feeding the collective.
    inputs: List[fx.Node]
    # The wait_tensor node(s) consuming the collective's result.
    wait_nodes: List[fx.Node]
    # The collective op itself (e.g. all_reduce).
    comm_node: fx.Node
    # Nodes whose values escape this block into the rest of the graph.
    outputs: Set[fx.Node]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
    """
    Given a collective node (e.g., allreduce), find out all the nodes belong to
    this communication.

    Args:
        comm_node(fx.Node): The target communication/collective node.
    Returns:
        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
        the given comm_node, or None when the surrounding nodes do not match
        the expected comm -> (getitem ->) wait_tensor structure.
    """
    node_list = []
    wait_nodes = []
    inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs))
    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
    # NOTE(review): ``wait_prefixes`` is never read below — it looks like dead
    # code left over from an earlier name-prefix matching scheme.
    wait_prefixes = "wait_tensor"
    # If the users of the wait node are following items, we consider them
    # to be a part of the output.
    intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias")

    first_user = next(iter(comm_node.users))
    if (
        len(comm_node.users) == 1
        and first_user.target == torch.ops._c10d_functional.wait_tensor.default
    ):
        # Collective with only one output
        node_list = [comm_node, first_user]
        wait_nodes.append(first_user)
    elif len(comm_node.users) > 1 and first_user.target == operator.getitem:
        # Collective with more than one output: every user must be a
        # single-use getitem feeding a wait_tensor.
        node_list.append(comm_node)
        for user in comm_node.users:
            if user.target != operator.getitem:
                return None
            if len(user.users) != 1:
                return None
            wait_node = next(iter(user.users))
            if wait_node.target != torch.ops._c10d_functional.wait_tensor.default:
                return None
            wait_nodes.append(wait_node)
            node_list.append(user)
        node_list.extend(wait_nodes)
    else:
        return None

    # Identify all the outputs of this collective block.
    outputs: Set[fx.Node] = set()
    nodes = collections.deque(wait_nodes)
    while nodes:
        node = nodes.popleft()
        for user in node.users:
            if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs):
                nodes.append(user)
                node_list.append(user)
            else:
                # First non-intermediate user: the current node's value
                # escapes the block; stop scanning its users.
                outputs.add(node)
                break

    # Derive the block's shape from the first tensor input's metadata.
    tensor_meta = input_nodes[0].meta["tensor_meta"]
    shape: Union[torch.Size, List[torch.Size]]
    if isinstance(tensor_meta, TensorMetadata):
        shape = tensor_meta.shape
    elif isinstance(tensor_meta, (list, tuple)):
        shape = [tm.shape for tm in tensor_meta]
    else:
        logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta))
        return None

    return CommBlock(
        shape=shape,
        node_list=node_list,
        wait_nodes=wait_nodes,
        comm_node=comm_node,
        inputs=input_nodes,
        outputs=outputs,
    )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_all_comm_blocks(
    graph: fx.Graph,
    comm_ops: Tuple[torch._ops.OpOverload, ...],
    comm_filter: Optional[Callable[..., bool]] = None,
) -> List[CommBlock]:
    """Collect every CommBlock in ``graph`` whose comm node targets ``comm_ops``.

    ``comm_filter`` optionally narrows the result; when omitted, every
    successfully parsed block is kept.
    """
    if comm_filter is None:

        def _keep_all(comm_block: CommBlock) -> bool:
            return True

        comm_filter = _keep_all

    result: List[CommBlock] = []
    candidates = (node for node in graph.nodes if node.target in comm_ops)
    for candidate in candidates:
        block = get_comm_block(candidate)
        if block is None:
            continue
        if comm_filter(block):
            result.append(block)
    return result
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _fuse_allreduce_by_concat(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using concat.

    The inputs are expected to be ``aten.div`` nodes (DDP divides gradients
    by the world size before the allreduce — see ``ddp_reducer_filter``); the
    per-tensor divs are replaced by one div over the concatenated buffer.
    """
    # Flatten all the inputs to the all_reduce nodes.
    with graph.inserting_after(last_input_node):
        cat_inputs = []
        for input_node in all_input_nodes:
            # Flatten each div's *dividend* (the gradient), not the div itself.
            assert isinstance(input_node.args[0], fx.Node)
            input_node = input_node.args[0]
            cat_inputs.append(
                call_function(graph, aten.flatten.using_ints, (input_node,))
            )

    # Concat all the flattened nodes.
    with graph.inserting_after(cat_inputs[0]):
        cat_node = call_function(graph, aten.cat, (cat_inputs,))

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    divisors = [div.args[1] for div in all_input_nodes]
    # All gradients must be divided by the same value (the world size).
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_after(cat_node):
        div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0]))

    # Create a new Comm/all_reduce node, reusing the last allreduce's
    # args/kwargs (process group etc.) with only the input swapped.
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]
    with graph.inserting_after(div_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = div_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs)

    # Create a new Wait node.
    with graph.inserting_after(fused_comm_node):
        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
        flatten_args[0] = fused_comm_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs)

    # Move the fused all_reduce and its args to right after the input node
    nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node]
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape,
        node_list=[fused_comm_node, fused_wait_node],
        wait_nodes=[fused_wait_node],
        comm_node=fused_comm_node,
        inputs=[div_node],
        outputs={fused_wait_node},
    )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _fuse_with_coalesced_op(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce by coalesced.

    Uses ``all_reduce_coalesced`` so the original tensors keep their shapes;
    the fused op returns a list, indexed via getitem, one wait per tensor.
    """
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    # (Inputs are expected to be aten.div nodes — see ddp_reducer_filter.)
    dividends = [div.args[0] for div in all_input_nodes]
    divisors = [div.args[1] for div in all_input_nodes]
    # All gradients must be divided by the same value (the world size).
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_before(last_input_node):
        last_input_node = call_function(
            graph, aten._foreach_div.Scalar, (dividends, divisors[0])
        )
    input_node = last_input_node

    # Create a new Comm/all_reduce_coalesced node, reusing the last
    # allreduce's args/kwargs with only the input swapped.
    with graph.inserting_after(last_comm_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = input_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(
            graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs
        )

    # Create a new wait node per coalesced output.
    getitem_nodes = []
    wait_nodes = []
    flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
    for idx in range(len(all_input_nodes)):
        with graph.inserting_after(fused_comm_node):
            gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx))
        getitem_nodes.append(gi_node)
        flatten_args[0] = gi_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        with graph.inserting_after(gi_node):
            wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs))

    # Move the new all_reduce_coalesced and its args to right after the input node
    nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=[
            tm.shape
            for tm in cast(
                List[TensorMetadata], fused_comm_node.meta.get("tensor_meta")
            )
        ],
        node_list=[fused_comm_node] + getitem_nodes + wait_nodes,
        wait_nodes=wait_nodes,
        comm_node=fused_comm_node,
        inputs=[input_node],
        outputs=set(wait_nodes),
    )
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _scatter_fused_allreduce_waits(
    graph: fx.Graph,
    fused_comm_block: CommBlock,
    orig_comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    split_and_reshape: bool = True,
) -> None:
    """
    Scatters the result of the fused communication node to the original users.
    If the fused method is concat, splitting the output and reshape will be
    inserted before the getitem. Otherwise getitem will be used as the users
    of the wait node.
    """

    # Before we mess up the order, we need to get the index of the last wait node
    # in orig_comm_blocks. This index will be later used to determine what user
    # nodes need to be moved to maintain a correct topological sort order.
    last_wait_node_idx = 0
    for node in graph.nodes:
        last_wait_node_idx = max(
            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
        )
        if node == orig_comm_blocks[-1].wait_nodes[0]:
            break

    if split_and_reshape:
        # Concat fusion: split the flat fused buffer by each block's numel,
        # then reshape each piece back to its original shape.
        fused_wait_node = fused_comm_block.wait_nodes[0]
        with graph.inserting_after(fused_wait_node):
            split_node = call_function(
                graph,
                aten.split,
                (
                    fused_wait_node,
                    [math.prod(cast(List[int], cb.shape)) for cb in orig_comm_blocks],
                ),
            )
        with graph.inserting_after(split_node):
            fused_outputs = []
            for idx, comm_block in enumerate(orig_comm_blocks):
                split_idx_node = call_function(
                    graph, operator.getitem, (split_node, idx)
                )
                with graph.inserting_after(split_idx_node):
                    fused_outputs.append(
                        call_function(
                            graph, aten.reshape, (split_idx_node, comm_block.shape)
                        )
                    )
    else:
        # Coalesced fusion already yields one wait per original tensor.
        fused_outputs = fused_comm_block.wait_nodes

    # Scatter the fused outputs.
    incorrect_order_nodes = []
    for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs):
        # Some descendant users of the orig_comm_blocks may be scheduled before
        # the fused all_reduce. For example, the user nodes of the very first
        # all_reduce may be scheduled before the second all_reduce. Since the
        # fused all_reduce is inserted right after the last all_reduce, the
        # order can be wrong.
        # `incorrect_order_nodes` records these nodes.

        orig_wait = comm_block.wait_nodes[0]
        nodes = collections.deque(list(orig_wait.users))
        while nodes:
            user_node = nodes.popleft()
            if not isinstance(user_node, fx.Node):
                continue
            if node_indices[user_node] < last_wait_node_idx:
                incorrect_order_nodes.append(user_node)
                nodes.extend(list(user_node.users))

        orig_wait.replace_all_uses_with(fused_output)

    # Find the last fused output in graph order; the misordered users must be
    # moved after it.
    last_fused_result = fused_outputs[0]
    fused_outputs_set = set(fused_outputs)
    for node in graph.nodes:
        if node in fused_outputs_set:
            last_fused_result = node

    # Move the incorrect_order_nodes to right after the last fused_result.
    incorrect_order_nodes = sorted(
        incorrect_order_nodes, key=lambda node: node_indices[node]
    )
    move_block_after(incorrect_order_nodes, last_fused_result)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _fuse_allreduce(
    graph: fx.Graph,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    use_concat: bool,
) -> CommBlock:
    """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.

    ``use_concat`` selects the concat-based fusion; otherwise the
    all_reduce_coalesced-based fusion is used.
    """

    # Nothing to fuse for a single block.
    if len(comm_blocks) == 1:
        return comm_blocks[0]

    # Find the last input node of all the CommBlocks. This node will be served
    # as the inserting point of the new collective op.
    last_input_node = comm_blocks[0].inputs[0]
    last_input_index = -1
    all_input_nodes = []
    for comm_block in comm_blocks:
        input_node = comm_block.inputs[0]
        all_input_nodes.append(input_node)
        index = node_indices[input_node]
        if index >= last_input_index:
            # Node indices must be unique — two inputs sharing an index would
            # make "last" ambiguous.
            assert index != last_input_index
            last_input_node = input_node
            last_input_index = index

    if use_concat:
        fused_comm_block = _fuse_allreduce_by_concat(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )
    else:
        fused_comm_block = _fuse_with_coalesced_op(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )

    # Rewire the original waits' users onto the fused block's outputs.
    _scatter_fused_allreduce_waits(
        graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat
    )

    # Remove the now-unused original collectives (waits first, since they
    # consume the comm nodes), then sweep any dangling feeders.
    for comm_block in comm_blocks:
        for wait in comm_block.wait_nodes:
            graph.erase_node(wait)
        graph.erase_node(comm_block.comm_node)
    graph.eliminate_dead_code()

    return fused_comm_block
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _bucket_size_fusion(
    graph: fx.Graph, comm_blocks: List[CommBlock], bucket_size_mb: int
) -> Generator[List[CommBlock], None, None]:
    # Group comm blocks into buckets by payload size. The first bucket is
    # capped at 1MB (presumably mirroring DDP's small first bucket — confirm
    # against DDP reducer defaults); subsequent buckets use bucket_size_mb.
    MB = 1024**2
    bucket_size = 1 * MB
    bucket_cap_size = bucket_size_mb * MB
    curr_size = 0
    curr_blocks = []

    count = 0
    fuse_count = 0
    for i, block in enumerate(comm_blocks):
        curr_blocks.append(block)
        # Payload bytes of this allreduce = numel * dtype item size.
        itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize
        curr_size += cast(torch.Size, block.shape).numel() * itemsize
        count += 1
        # Keep filling the current bucket unless it is full or this is the
        # final block (the last, possibly underfull, bucket is still yielded).
        if curr_size < bucket_size and i != len(comm_blocks) - 1:
            continue

        fuse_count += 1
        # Only rank 0 logs, to avoid duplicated output across the job.
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d",
                fuse_count,
                count,
                curr_size,
                bucket_size,
            )

        # Set the debug counters
        counters["inductor"]["ddp_buckets"] = fuse_count
        yield curr_blocks

        # Start a new bucket; after the first yield switch to the full cap.
        bucket_size = bucket_cap_size
        curr_blocks = []
        curr_size = 0
        count = 0
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def _fuse_ddp_communication(
    graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any]
) -> None:
    # Locate the graph's output node first; ddp_reducer_filter (below) needs
    # it to decide whether a wait node feeds the output directly.
    for output in reversed(graph.nodes):
        if output.op == "output":
            break

    def ddp_reducer_filter(block: CommBlock) -> bool:
        # Only consider allreduces whose input is produced by aten.div.Tensor
        # (presumably DDP's gradient averaging by world size — confirm).
        if (
            not isinstance(block.comm_node.args[0], fx.Node)
            or block.comm_node.args[0].target != aten.div.Tensor
        ):
            return False

        if len(block.wait_nodes[0].users) != 1:
            # gradient/wait node should only be used by one user
            return False

        # Two cases:
        # 1. gradient/wait node should be directly used by the output
        # if gradient is None before bwd.
        # 2. gradient/wait node should be directly used by copy_.
        if (
            output not in block.wait_nodes[0].users
            and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default
        ):
            return False

        return True

    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter)
    node_indices = {node: i for i, node in enumerate(graph.nodes)}

    # algorithm_fn groups the blocks (e.g. by bucket size) and fusion_fn
    # fuses each group into a single collective.
    for block in algorithm_fn(graph, comm_blocks):
        fusion_fn(graph, block, node_indices)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    """Fuse DDP allreduces per bucket using the coalesced collective op."""
    bucketing = partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb)
    fuse = partial(_fuse_allreduce, use_concat=False)
    _fuse_ddp_communication(graph, bucketing, fuse)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    """Fuse DDP allreduces per bucket by concatenating inputs into one tensor."""
    bucketing = partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb)
    fuse = partial(_fuse_allreduce, use_concat=True)
    _fuse_ddp_communication(graph, bucketing, fuse)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def schedule_comm_wait(graph: fx.Graph) -> None:
    """
    Delay the execution of wait tensors of allreduce until its first user.

    This algorithm considers the intermediate users, like split, getitem,
    of the wait node and schedule those intermediate users as well.
    This will result in a better overlapping result.
    """
    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
        torch.ops._c10d_functional.all_reduce_coalesced.default,
        torch.ops._c10d_functional.all_reduce_coalesced_.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops)
    if not comm_blocks:
        return

    # Find all the end users.
    allreduce_users: Set[fx.Node] = set()
    for allreduce in comm_blocks:
        for output in allreduce.outputs:
            allreduce_users.update(output.users)

    node_indices = {node: i for i, node in enumerate(graph.nodes)}
    for allreduce in comm_blocks:
        # Find the earliest/first user -- target_node.
        assert (
            len(allreduce.outputs) >= 1
        ), f"Found a allreduce that has zero outputs/users -- {allreduce}."
        # Initialize the target node to avoid typing issues.
        target_node = next(iter(next(iter(allreduce.outputs)).users))
        # 2**31 acts as "infinity" for the min-index scan below.
        target_node_index = 2**31
        for user in (user for output in allreduce.outputs for user in output.users):
            index = node_indices[user]
            if index < target_node_index:
                target_node = user
                target_node_index = index

        # Move wait nodes and all the subsequent nodes in the comm_block to
        # before the first user -- target_node.
        # wait_idx is the position of the first wait node within node_list.
        wait_idx = -1
        for wait_idx, node in enumerate(allreduce.node_list):
            if node == allreduce.wait_nodes[0]:
                break
        assert wait_idx >= 0
        move_block_before(allreduce.node_list[wait_idx:], target_node)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def fuse_ddp_communication(
    graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int
) -> None:
    """Run the configured DDP-communication fusion passes over ``graph``.

    Each entry of ``passes`` is either a callable or the name of a function in
    this module's globals. Passes whose signature declares a
    ``bucket_size_mb`` parameter receive it; the rest are called with the
    graph only. Every pass runs under a GraphTransformObserver so the
    transform can be logged/traced.
    """
    for i, pa in enumerate(passes):
        with GraphTransformObserver(
            graph.owning_module,
            f"fuse_ddp_communication_pass_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            # String entries name functions defined in this module.
            func = globals()[pa] if isinstance(pa, str) else pa
            # signature().parameters is already a name-keyed mapping, so test
            # membership directly instead of materializing a set of names.
            if "bucket_size_mb" in inspect.signature(func).parameters:
                func(graph, bucket_size_mb=bucket_size_mb)
            else:
                func(graph)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch._dynamo.utils import counters
|
| 8 |
+
|
| 9 |
+
from .. import config
|
| 10 |
+
from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
|
| 11 |
+
from .split_cat import construct_pattern_matcher_pass
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
aten = torch.ops.aten
log = logging.getLogger(__name__)

# TODO: need a better strategy for decomposing mm
# Default thresholds: decompose only when the leading dimension is very large
# and the remaining dimensions are small (the "memory-bound" matmul shape).
MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
MAX_OTHER_DIMENSION_DECOMPOSITION = 32

min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION
max_other_dimention_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION
# Allow per-run overrides via the "decompose_mm_pass" post-grad fusion option.
# NOTE: the "dimention" misspelling is kept — it is the externally-visible
# config key, so renaming it would break existing configs.
if "decompose_mm_pass" in config.post_grad_fusion_options:
    min_first_dimension_decomposition = config.post_grad_fusion_options[
        "decompose_mm_pass"
    ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION)
    max_other_dimention_decomposition = config.post_grad_fusion_options[
        "decompose_mm_pass"
    ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def check_device(a: Tensor, b: Tensor) -> bool:
    """Return True only when both tensors live on a CUDA device."""
    return all(t.is_cuda for t in (a, b))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def realize_inputs(inputs: List[torch.fx.Node]):
    """Mark each fx.Node in ``inputs`` to be realized to its strides.

    Non-Node entries (e.g. constants) are ignored.
    """
    for candidate in inputs:
        if not isinstance(candidate, torch.fx.node.Node):
            continue
        candidate.meta["inductor_realize_to_strides"] = True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def should_decompose_bmm(mat1, mat2) -> bool:
    """Decide whether a bmm(mat1, mat2) node should be decomposed.

    Both arguments are fx nodes; the decision is made from their "val" meta
    tensors: both must be CUDA, rank 3, have a large batch dimension, and at
    least two of the remaining dims must be small.
    """
    if not (is_node_meta_valid(mat1) and is_node_meta_valid(mat2)):
        return False
    lhs = mat1.meta["val"]
    rhs = mat2.meta["val"]
    if not check_device(lhs, rhs):
        return False
    if len(lhs.shape) != 3 or len(rhs.shape) != 3:
        return False
    if lhs.shape[0] < min_first_dimension_decomposition:
        return False
    # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
    small_dims = (
        int(lhs.shape[1] < max_other_dimention_decomposition)
        + int(lhs.shape[2] < max_other_dimention_decomposition)
        + int(rhs.shape[2] < max_other_dimention_decomposition)
    )
    return small_dims >= 2
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def should_decompose_mm(mat1, mat2) -> bool:
    """Decide whether mm(mat1, mat2) is memory-bound enough to decompose.

    Requires valid "val" meta on both fx nodes, CUDA tensors, rank 2, a large
    first dimension on mat1, and both dims of mat2 small.
    """
    if not (is_node_meta_valid(mat1) and is_node_meta_valid(mat2)):
        return False
    lhs = mat1.meta["val"]
    rhs = mat2.meta["val"]
    if not check_device(lhs, rhs):
        return False
    if len(lhs.shape) != 2 or len(rhs.shape) != 2:
        return False
    if lhs.shape[0] < min_first_dimension_decomposition:
        return False
    return (
        rhs.shape[0] < max_other_dimention_decomposition
        and rhs.shape[1] < max_other_dimention_decomposition
    )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def is_node_meta_valid(node: torch.fx.Node):
    """Return True when the node's meta carries a "val" entry."""
    meta = node.meta
    return "val" in meta
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
    """Log (at debug level) which op was decomposed and its input shapes."""
    last_node = match.nodes[-1]
    shape_strs = [
        str(inp.meta["val"].shape) if "val" in inp.meta else "None" for inp in inputs
    ]
    log.debug(
        "Decompose %s with input shape: %s",
        last_node.target,
        ", ".join(shape_strs),
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@register_graph_pattern(
    CallFunction(aten.bmm, Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
    # Decompose a memory-bound bmm into a broadcasted multiply + sum over the
    # shared dimension; replace_by_example traces this reference impl.
    def repl(mat1, mat2):
        return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to(
            mat1.dtype
        )

    if should_decompose_bmm(mat1, mat2):
        counters["inductor"]["decompose_bmm"] += 1
        match.replace_by_example(repl, [mat1, mat2])
        print_decompose_pattern(match, [mat1, mat2])
        # Mark the matched inputs so inductor realizes them to their strides.
        realize_inputs([mat1, mat2])
    return
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@register_graph_pattern(
    CallFunction(aten.addmm, Arg(), Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_addmm(
    match: Match,
    mat1: torch.fx.Node,
    mat2: torch.fx.Node,
    mat3: torch.fx.Node,
):
    # addmm(bias, a, b): decompose the a@b part into broadcasted multiply +
    # sum, then add the bias; replace_by_example traces this reference impl.
    def repl(mat1, mat2, mat3):
        return (
            torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1
        )

    # The decomposition decision only looks at the matmul operands (mat2/mat3).
    if should_decompose_mm(mat2, mat3):
        counters["inductor"]["decompose_addmm"] += 1
        match.replace_by_example(repl, [mat1, mat2, mat3])
        print_decompose_pattern(match, [mat1, mat2, mat3])
        # Mark the matched inputs so inductor realizes them to their strides.
        realize_inputs([mat1, mat2, mat3])
    return
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@register_graph_pattern(
    CallFunction(aten.mm, Arg(), Arg()),
    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
)
def decompose_mm(
    match: Match,
    mat1: torch.fx.Node,
    mat2: torch.fx.Node,
):
    # mm(a, b) -> broadcasted multiply + sum over the shared dimension;
    # replace_by_example traces this reference implementation into the graph.
    def repl(mat1, mat2):
        return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype)

    if should_decompose_mm(mat1, mat2):
        counters["inductor"]["decompose_mm"] += 1
        match.replace_by_example(repl, [mat1, mat2])
        print_decompose_pattern(match, [mat1, mat2])
        # Mark the matched inputs so inductor realizes them to their strides.
        realize_inputs([mat1, mat2])
    return
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import SymBool, SymFloat, SymInt
|
| 7 |
+
from torch.types import py_sym_types
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class _SymExprHash:
    """
    Hash for a py_sym_types that will use the underlying sympy expression
    """

    # The wrapped symbolic scalar whose .node.expr drives hashing/equality.
    sym_obj: Union[SymInt, SymFloat, SymBool]

    # NOTE(review): the explicitly defined __hash__/__eq__ below take
    # precedence over the dataclass-generated ones — confirm this matches the
    # dataclasses "already defined in the class" rule.
    def __hash__(self) -> int:
        # Include the concrete sym type so e.g. a SymInt and a SymBool with
        # the same expression hash differently.
        return hash((type(self.sym_obj), self.sym_obj.node.expr))

    def __eq__(self, value) -> bool:
        if not isinstance(value, _SymExprHash):
            return False
        # Compare the underlying sympy expressions, not the proxy objects.
        return self.sym_obj.node.expr == value.sym_obj.node.expr
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _SymHashingDict:
|
| 28 |
+
"""
|
| 29 |
+
Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
|
| 30 |
+
existing sym proxies.
|
| 31 |
+
|
| 32 |
+
SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
|
| 33 |
+
fallback to symnodes.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.sym_hash_dict = {}
|
| 38 |
+
|
| 39 |
+
def __setitem__(self, key, value):
|
| 40 |
+
self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
|
| 41 |
+
|
| 42 |
+
def __getitem__(self, key):
|
| 43 |
+
return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
|
| 44 |
+
|
| 45 |
+
def __contains__(self, key):
|
| 46 |
+
return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
|
| 47 |
+
|
| 48 |
+
def get(self, key, default=None):
|
| 49 |
+
return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
|
| 50 |
+
|
| 51 |
+
def _wrap_to_sym_expr_hash(self, key):
|
| 52 |
+
return _SymExprHash(key) if isinstance(key, py_sym_types) else key
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def dedupe_symints(graph: torch.fx.Graph):
    """
    Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs.

    We only dedupe from graph inputs to avoid adding a potential dependency in the forward
    from the backward.

    """

    sym_dict = _SymHashingDict()
    resolvable_from_input_symints = set()

    for node in graph.nodes:
        # Only consider nodes producing a symbolic scalar (SymInt/Float/Bool).
        val = node.meta.get("val", None)
        if val is None or not isinstance(val, py_sym_types):
            continue

        if node.op == "placeholder":
            # Graph inputs seed both the dedupe map and the resolvable set.
            resolvable_from_input_symints.add(node)
            sym_dict[val] = node
        elif existing_node := sym_dict.get(val):
            # Same underlying sympy expression already available — reuse it
            # and drop this duplicate node.
            node.replace_all_uses_with(existing_node)
            graph.erase_node(node)
        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
            # Derived purely from input symints, so later duplicates can be
            # deduped against this node safely.
            sym_dict[val] = node
            resolvable_from_input_symints.add(node)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch._dynamo.utils import counters
|
| 5 |
+
from torch._inductor import config as inductor_config
|
| 6 |
+
from torch.func import functional_call
|
| 7 |
+
|
| 8 |
+
from ..pattern_matcher import (
|
| 9 |
+
CallFunctionVarArgs,
|
| 10 |
+
CallModuleVarArgs,
|
| 11 |
+
Match,
|
| 12 |
+
register_graph_pattern,
|
| 13 |
+
)
|
| 14 |
+
from .pre_grad import efficient_conv_bn_eval_pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def efficient_conv_bn_eval(
|
| 18 |
+
bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Implementation based on https://arxiv.org/abs/2305.11624
|
| 22 |
+
"Efficient ConvBN Blocks for Transfer Learning and Beyond"
|
| 23 |
+
It leverages the associative law between convolution and affine transform,
|
| 24 |
+
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
|
| 25 |
+
It works for Eval mode of ConvBN blocks during validation, and can be used
|
| 26 |
+
for **training** as well, but only if one sets `bn.training=False`. It
|
| 27 |
+
reduces memory footprint and computation cost, at the cost of slightly
|
| 28 |
+
reduced numerical stability.
|
| 29 |
+
Args:
|
| 30 |
+
bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
|
| 31 |
+
conv (nn.modules.conv._ConvNd): a conv module
|
| 32 |
+
x (torch.Tensor): Input feature map.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
assert bn.running_var is not None
|
| 36 |
+
|
| 37 |
+
# These lines of code are designed to deal with various cases
|
| 38 |
+
# like bn without affine transform, and conv without bias
|
| 39 |
+
weight_on_the_fly = conv.weight
|
| 40 |
+
if conv.bias is not None:
|
| 41 |
+
bias_on_the_fly = conv.bias
|
| 42 |
+
else:
|
| 43 |
+
bias_on_the_fly = torch.zeros_like(bn.running_var)
|
| 44 |
+
|
| 45 |
+
if bn.weight is not None:
|
| 46 |
+
bn_weight = bn.weight
|
| 47 |
+
else:
|
| 48 |
+
bn_weight = torch.ones_like(bn.running_var)
|
| 49 |
+
|
| 50 |
+
if bn.bias is not None:
|
| 51 |
+
bn_bias = bn.bias
|
| 52 |
+
else:
|
| 53 |
+
bn_bias = torch.zeros_like(bn.running_var)
|
| 54 |
+
|
| 55 |
+
# shape of [C_out, 1, 1, 1] in Conv2d
|
| 56 |
+
target_shape = [-1] + [1] * (conv.weight.ndim - 1)
|
| 57 |
+
if isinstance(conv, nn.modules.conv._ConvTransposeNd):
|
| 58 |
+
# for transposed conv, the C_out dimension should at index 1.
|
| 59 |
+
target_shape[:2] = [target_shape[1], target_shape[0]]
|
| 60 |
+
weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape)
|
| 61 |
+
# shape of [C_out, 1, 1, 1] in Conv2d
|
| 62 |
+
coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
|
| 63 |
+
|
| 64 |
+
# shape of [C_out, C_in, k, k] in Conv2d
|
| 65 |
+
weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
|
| 66 |
+
# shape of [C_out] in Conv2d
|
| 67 |
+
bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
|
| 68 |
+
bias_on_the_fly - bn.running_mean
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
input = x
|
| 72 |
+
params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly}
|
| 73 |
+
output = functional_call(conv, params, input)
|
| 74 |
+
return output
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def efficient_conv_bn_eval_decomposed(
|
| 78 |
+
bn_weight,
|
| 79 |
+
bn_bias,
|
| 80 |
+
bn_running_mean,
|
| 81 |
+
bn_running_var,
|
| 82 |
+
bn_eps,
|
| 83 |
+
conv: torch._ops.OpOverload,
|
| 84 |
+
conv_weight,
|
| 85 |
+
conv_bias,
|
| 86 |
+
x,
|
| 87 |
+
conv_remainging_args,
|
| 88 |
+
):
|
| 89 |
+
"""
|
| 90 |
+
Implementation based on https://arxiv.org/abs/2305.11624
|
| 91 |
+
"Efficient ConvBN Blocks for Transfer Learning and Beyond"
|
| 92 |
+
It leverages the associative law between convolution and affine transform,
|
| 93 |
+
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
|
| 94 |
+
It works for Eval mode of ConvBN blocks during validation, and can be used
|
| 95 |
+
for **training** as well, but only if one sets `bn.training=False`. It
|
| 96 |
+
reduces memory footprint and computation cost, at the cost of slightly
|
| 97 |
+
reduced numerical stability.
|
| 98 |
+
Args:
|
| 99 |
+
"""
|
| 100 |
+
assert bn_running_var is not None
|
| 101 |
+
|
| 102 |
+
# These lines of code are designed to deal with various cases
|
| 103 |
+
# like bn without affine transform, and conv without bias
|
| 104 |
+
weight_on_the_fly = conv_weight
|
| 105 |
+
if conv_bias is not None:
|
| 106 |
+
bias_on_the_fly = conv_bias
|
| 107 |
+
else:
|
| 108 |
+
bias_on_the_fly = torch.zeros_like(bn_running_var)
|
| 109 |
+
|
| 110 |
+
if bn_weight is not None:
|
| 111 |
+
bn_weight = bn_weight
|
| 112 |
+
else:
|
| 113 |
+
bn_weight = torch.ones_like(bn_running_var)
|
| 114 |
+
|
| 115 |
+
if bn_bias is not None:
|
| 116 |
+
bn_bias = bn_bias
|
| 117 |
+
else:
|
| 118 |
+
bn_bias = torch.zeros_like(bn_running_var)
|
| 119 |
+
|
| 120 |
+
# shape of [C_out, 1, 1, 1] in Conv2d
|
| 121 |
+
target_shape = [-1] + [1] * (conv_weight.ndim - 1)
|
| 122 |
+
if "conv_transpose" in conv.__str__():
|
| 123 |
+
# for transposed conv, the C_out dimension should at index 1.
|
| 124 |
+
target_shape[:2] = [target_shape[1], target_shape[0]]
|
| 125 |
+
weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
|
| 126 |
+
# shape of [C_out, 1, 1, 1] in Conv2d
|
| 127 |
+
coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
|
| 128 |
+
|
| 129 |
+
# shape of [C_out, C_in, k, k] in Conv2d
|
| 130 |
+
weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
|
| 131 |
+
# shape of [C_out] in Conv2d
|
| 132 |
+
bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
|
| 133 |
+
bias_on_the_fly - bn_running_mean
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
input = x
|
| 137 |
+
return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.nn.functional.batch_norm,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
    # Rewrite F.batch_norm(conv(...)) in eval mode into a single call to
    # efficient_conv_bn_eval_decomposed.
    bn_node = match.nodes[0]
    graph = match.graph
    assert len(bn_node.args) == 8

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args is `training`
    if bn_node.args[-3]:
        return

    # Check if the input is Conv
    input_node = bn_node.args[0]

    if input_node.op != "call_function":  # type: ignore[union-attr]
        return

    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
    supported_convs = [
        torch._C._nn.linear,
        torch.conv1d,
        torch.conv2d,
        torch.conv3d,
        torch.conv_transpose1d,
        torch.conv_transpose2d,
        torch.conv_transpose3d,
    ]

    # Identity check (`is`) because these are specific function objects.
    if not any(input_fn is cls for cls in supported_convs):
        return

    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        # (argument positions follow F.batch_norm: input, running_mean,
        # running_var, weight, bias, training, momentum, eps)
        bn_running_mean = bn_node.args[1]
        bn_running_var = bn_node.args[2]
        bn_weight = bn_node.args[3]
        bn_bias = bn_node.args[4]
        bn_eps = bn_node.args[7]
        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
        conv_input = conv_node.args[0]  # type: ignore[union-attr]
        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
        args = (
            bn_weight,
            bn_bias,
            bn_running_mean,
            bn_running_var,
            bn_eps,
            conv_node.target,  # type: ignore[union-attr]
            conv_weight,
            conv_bias,
            conv_input,
            conv_remainging_args,
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)  # type: ignore[arg-type]

    return
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.ops.aten.batch_norm.default,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
    # Rewrite aten.batch_norm(aten.conv*(...)) in eval mode into a single
    # call to efficient_conv_bn_eval_decomposed.
    bn_node = match.nodes[0]
    graph = match.graph
    assert len(bn_node.args) == 9

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args is `training`
    if bn_node.args[-4]:
        return

    # Check if the input is Conv
    input_node = bn_node.args[0]

    if input_node.op != "call_function":  # type: ignore[union-attr]
        return

    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
    supported_convs = [
        torch.ops.aten.linear.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv3d.default,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d.input,
        torch.ops.aten.conv_transpose3d.input,
    ]

    # Identity check (`is`) because these are specific OpOverload objects.
    if not any(input_fn is cls for cls in supported_convs):
        return

    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        # (argument positions follow aten.batch_norm: input, weight, bias,
        # running_mean, running_var, training, momentum, eps, cudnn_enabled)
        bn_weight = bn_node.args[1]
        bn_bias = bn_node.args[2]
        bn_running_mean = bn_node.args[3]
        bn_running_var = bn_node.args[4]
        bn_eps = bn_node.args[7]
        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
        conv_input = conv_node.args[0]  # type: ignore[union-attr]
        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
        args = (
            bn_weight,
            bn_bias,
            bn_running_mean,
            bn_running_var,
            bn_eps,
            conv_node.target,  # type: ignore[union-attr]
            conv_weight,
            conv_bias,
            conv_input,
            conv_remainging_args,
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)  # type: ignore[arg-type]

    return
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@register_graph_pattern(
    CallModuleVarArgs(
        [
            nn.modules.batchnorm._BatchNorm,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
        ],
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    """Fuse a call_module conv followed by a call_module batch-norm (eval mode)
    into a single ``efficient_conv_bn_eval`` call_function node.

    Bails out (returns ``None``) when the BN is training / not tracking
    running stats, the BN input is not a supported conv/linear module, or the
    conv output has more than one user.
    """
    # We matched a BN node
    bn_node = match.nodes[0]
    graph = match.graph
    gm = graph.owning_module
    bn_mod = getattr(gm, bn_node.target)  # type: ignore[arg-type]

    # We can only use efficient conv-bn for eval mode with track_running_stats
    if not bn_mod.track_running_stats or bn_mod.training:
        return

    # Check if the input is Conv
    if bn_node.args:
        input_node = bn_node.args[0]
    else:
        input_node = bn_node.kwargs["input"]
    if input_node.op != "call_module":  # type: ignore[union-attr]
        return
    if not hasattr(gm, input_node.target):  # type: ignore[arg-type, union-attr]
        return
    input_mod = getattr(gm, input_node.target)  # type: ignore[arg-type, union-attr]
    supported_convs = [
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    ]
    if not any(isinstance(input_mod, cls) for cls in supported_convs):
        return
    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Find a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
        # `graph.call_function` does not allow the `name` argument.
        conv_get_node = graph.create_node(
            op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
        )
        bn_get_node = graph.create_node(
            op="get_attr", target=bn_node.target, name="get_bn"
        )
        if conv_node.args:  # type: ignore[union-attr]
            conv_input = conv_node.args[0]  # type: ignore[union-attr]
        else:
            conv_input = conv_node.kwargs["input"]  # type: ignore[union-attr]
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_input)
        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval,
            args=args,
            name="efficient_conv_bn_eval",
        )
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        # (bn still uses conv's output, so bn must be erased before conv)
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)  # type: ignore[arg-type]
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._inductor.compile_fx import fake_tensor_prop
|
| 6 |
+
|
| 7 |
+
from ..._dynamo.utils import counters
|
| 8 |
+
from .. import config
|
| 9 |
+
from ..pattern_matcher import (
|
| 10 |
+
_return_true,
|
| 11 |
+
CallFunction,
|
| 12 |
+
fwd_only,
|
| 13 |
+
Ignored,
|
| 14 |
+
init_once_fakemode,
|
| 15 |
+
KeywordArg,
|
| 16 |
+
Match,
|
| 17 |
+
PatternMatcherPass,
|
| 18 |
+
register_graph_pattern,
|
| 19 |
+
register_replacement,
|
| 20 |
+
stable_topological_sort,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Shortcut to the ATen op namespace used by the patterns below.
aten = torch.ops.aten

# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]

# Separate pass for conv+binary folding; applied repeatedly inside
# freezing_passes until no new foldings occur.
binary_folding_pass = PatternMatcherPass()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes that are applied to the graph to freeze pass.

    Pipeline: constant-fold + binary-fold in a fixed-point loop (max 4
    rounds), then the three staged ``pass_patterns``, then optional mkldnn
    weight-packing dedup, and finally re-sort/recompile/lint the graph.
    The graph is mutated in place.
    """

    from ..freezing import constant_fold

    lazy_init()
    # We need a few rounds of binary folding to get rid of all the
    # unnecessary nodes, but may need a good method to chose the rounds number.
    # works like: conv+binary+binary.
    binary_folding = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # If we don't have binary folding, we don't need to run the pass again.
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == binary_folding:
            break
        binary_folding = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm)

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

    for pattern in pass_patterns:
        pattern.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assume the conv's weight is channels last,
    # So make sure the layout_optimization is on when doing it.
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@init_once_fakemode
def lazy_init():
    """One-time (fake-mode) registration of the freezing pattern sets:
    mkldnn weight-pack patterns (when available), addmm concat patterns,
    and binary-folding patterns."""
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    # imported lazily to avoid an import cycle at module load time
    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    """Register *pattern* into the staged freezing pass given by *pass_number*."""
    target_pass = pass_patterns[pass_number]
    return register_graph_pattern(pattern, extra_check=extra_check, pass_dict=target_pass)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_binary_folding_pattern(pattern, extra_check=_return_true):
    """Register *pattern* into the dedicated binary-folding pass."""
    return register_graph_pattern(
        pattern,
        pass_dict=binary_folding_pass,
        extra_check=extra_check,
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@functools.lru_cache(None)
def addmm_patterns_init():
    """Register (once) the matmul/addmm horizontal-fusion replacements:
    several matmuls sharing one input become a single matmul against the
    concatenated weights, followed by ``chunk``.

    NOTE: the pattern functions below are traced by the pattern matcher, so
    their exact op sequence is significant — do not restyle them.
    """
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    # Example-input factory used to trace the patterns below.
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
        # Fusion is only valid when all weights (and biases, if present) are
        # graph constants (get_attr) with identical shapes.
        weight_inputs = ["w1", "w2"]
        if "w3" in match.kwargs:
            weight_inputs.append("w3")

        equal_shape_inputs = [weight_inputs]

        if "b1" in match.kwargs:
            bias_inputs = ["b1", "b2"]
            if "b3" in match.kwargs:
                bias_inputs.append("b3")

            equal_shape_inputs.append(bias_inputs)

        for equal_shape_group in equal_shape_inputs:
            inps = [match.kwargs[name] for name in equal_shape_group]

            if not all(
                inp.op == "get_attr"
                and inp.meta["val"].shape == inps[0].meta["val"].shape
                for inp in inps
            ):
                return False

        return True

    def matmul_fuse_pattern(inp, w1, w2, w3):
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        # One matmul against the column-concatenated weights, then split.
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

    def matmul_fuse_pattern_two(inp, w1, w2):
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def same_dtype(match):
    """True iff the convert_element_type target dtype equals its input's dtype."""
    source = match.output_node().args[0]
    return source.meta["val"].dtype == match.kwargs["dtype"]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    # only fires when the conversion is a no-op (input already has `dtype`)
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    graph = match.graph
    node = match.output_node()
    # Rewire consumers straight to the convert's input, then drop the node.
    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
    graph.erase_node(node)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py
ADDED
|
@@ -0,0 +1,909 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
| 9 |
+
|
| 10 |
+
from ..._dynamo.utils import counters
|
| 11 |
+
from ..pattern_matcher import (
|
| 12 |
+
filter_nodes,
|
| 13 |
+
fwd_only,
|
| 14 |
+
gen_register_replacement,
|
| 15 |
+
joint_fwd_bwd,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
log = logging.getLogger(__name__)
|
| 20 |
+
aten = torch.ops.aten
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if torch.version.hip:
    # On ROCm, restrict SDPA to the MATH and FLASH_ATTENTION backends
    # (efficient-attention is excluded here).
    def _scaled_dot_product_attention(*args, **kwargs):
        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
            return aten.scaled_dot_product_attention(*args, **kwargs)

else:
    # Elsewhere, use the ATen op directly with its default backend selection.
    _scaled_dot_product_attention = aten.scaled_dot_product_attention
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _sfdp_pattern_1(query, key, value, inv_scale):
    """Pattern: softmax(q @ k^T / inv_scale) @ v.

    NOTE: traced by the pattern matcher — the op sequence must match real
    workloads exactly; do not restyle.
    """
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .matmul(value)
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _sfdp_replacement_1(query, key, value, inv_scale):
    """Replacement for pattern 1: fused SDPA with scale = 1/inv_scale."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _sfdp_pattern_2(query, key, value, scale_factor):
    """Pattern: softmax(q @ k^T * scale_factor) @ v (traced; do not restyle)."""
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .mul(scale_factor)
        .softmax(dim=-1)
        .matmul(value)
    )
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _sfdp_replacement_2(query, key, value, scale_factor):
    """Replacement for pattern 2: fused SDPA with scale = scale_factor."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=scale_factor,
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    """Pattern 1 plus dropout on the attention weights (traced; do not restyle)."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale_factor)
        .softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
    """Replacement for pattern 3: fused SDPA with dropout and scale = 1/inv_scale_factor."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
    """Pattern 2 plus dropout on the attention weights (traced; do not restyle)."""
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
    """Replacement for pattern 4: fused SDPA with dropout and scale = scale_factor."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale_factor,
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _sfdp_pattern_5(query, key, value, attn_mask):
    """Pattern: softmax(q @ k^T / sqrt(d) + mask) @ v (traced; do not restyle)."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    # attn_weight = torch.dropout(attn_weight, dropout_p)
    return attn_weight @ value
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _sfdp_replacement_5(query, key, value, attn_mask):
    """Replacement for pattern 5: fused SDPA with an additive attention mask."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        # mask is cast to the query dtype so SDPA treats it as additive
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
    )
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
    """Pattern 5 plus dropout on the attention weights (traced; do not restyle)."""
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    """Replacement for pattern 6: fused SDPA with additive mask and dropout."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _sfdp_pattern_7(query, key, value, dropout_p):
    """Permuted-layout attention with fp32 softmax and dropout (traced; do not restyle)."""
    # in real workloads inputs to matmul are permuted
    # causing matmul to expand to a series of expand and clone calls
    # we want the same to happen during pattern tracing
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # softmax is computed in float32 then cast back, mirroring mixed-precision workloads
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _sfdp_replacement_7(query, key, value, dropout_p):
    """Replacement for pattern 7: fused SDPA on the permuted inputs."""
    # sdpa prefers inputs in permuted format
    # it makes a copy to put them in this format
    # if they aren't already
    # to make replacement efficient ensure that inputs to sdpa
    # are in required order
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _sfdp_pattern_8(query, key, value):
    """No-dropout version of pattern 7 (traced; do not restyle)."""
    # no dropout version of pattern 7
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _sfdp_replacement_8(query, key, value):
    """Replacement for pattern 8: fused SDPA on the permuted inputs, no dropout."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _sfdp_pattern_9(query, key, value, dropout_p):
    """Like pattern 7 but the query is pre-scaled before the matmul (traced; do not restyle)."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    # scale applied to q up front instead of to the score matrix
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _sfdp_replacement_9(query, key, value, dropout_p):
    """Replacement for pattern 9: fused SDPA on the permuted inputs with dropout."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _sfdp_pattern_10(query, key, value):
    """No-dropout version of pattern 9 (traced; do not restyle)."""
    # no dropout version of 9
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _sfdp_replacement_10(query, key, value):
    """Replacement for pattern 10: fused SDPA on the permuted inputs, no dropout."""
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _sfdp_pattern_11(query, key, value, inv_scale):
    """Permuted pattern 1, seen mainly in HuggingFace models (traced; do not restyle)."""
    # Mainly for huggingface models
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _sfdp_replacement_11(query, key, value, inv_scale):
    """Replacement for pattern 11: fused SDPA; transpose replaces the permute."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
    """Pattern 11 plus dropout on the attention weights (traced; do not restyle)."""
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.nn.functional.dropout(
        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(v)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    """Replacement for pattern 12: fused SDPA with dropout; transpose replaces permute."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _sfdp_pattern_13(query, key, value, dropout_p):
    """Batched (3-D, bmm-based) unscaled attention with dropout (traced; do not restyle)."""
    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
    return torch.bmm(attn_weight, value)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _sfdp_replacement_13(query, key, value, dropout_p):
    """Replacement for pattern 13: fused SDPA; unsqueeze/squeeze adapt the
    3-D bmm inputs to SDPA's 4-D convention, with scale=1.0 (pattern is unscaled)."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        dropout_p=dropout_p,
        scale=1.0,
    ).squeeze(0)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
    """Permuted masked attention as seen in BertLarge (traced; do not restyle)."""
    # for BertLarge
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
        .softmax(dim=-1)
        .matmul(v)
    )
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
    """Replacement for pattern 14: fused SDPA with additive mask and scale = 1/inv_scale."""
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
|
| 365 |
+
# for DistilBert
|
| 366 |
+
# Permutations are needed to create clones in graph.
|
| 367 |
+
# Ref: https://github.com/pytorch/pytorch/issues/119911
|
| 368 |
+
q = query.permute([0, 2, 1, 3])
|
| 369 |
+
k = key.permute([0, 2, 1, 3])
|
| 370 |
+
v = value.permute([0, 2, 1, 3])
|
| 371 |
+
bs = q.size(0)
|
| 372 |
+
k_len = k.size(-2)
|
| 373 |
+
scores = q @ k.transpose(-2, -1)
|
| 374 |
+
scores = scores.div(inv_scale)
|
| 375 |
+
fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
|
| 376 |
+
attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
|
| 377 |
+
return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
    """SDPA replacement for _sfdp_pattern_15 (DistilBert).

    The pattern masks where attn_mask == 0; SDPA keeps where the boolean
    mask is True, so we pass (attn_mask == 1) expanded to full score shape.
    """
    counters["inductor"]["fuse_attention"] += 1
    batch = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    key_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    keep_mask = (
        (attn_mask == 1)
        .view((batch, 1, 1, key_len))
        .expand((batch, n_head, q_len, key_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=keep_mask.to(dtype=torch.bool),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
|
| 402 |
+
# for BertLarge with dropout
|
| 403 |
+
q = query.permute([0, 2, 1, 3])
|
| 404 |
+
k = key.permute([0, 2, 1, 3])
|
| 405 |
+
v = value.permute([0, 2, 1, 3])
|
| 406 |
+
return (
|
| 407 |
+
torch.nn.functional.dropout(
|
| 408 |
+
(torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
|
| 409 |
+
dim=-1
|
| 410 |
+
),
|
| 411 |
+
dropout_p,
|
| 412 |
+
)
|
| 413 |
+
.to(dtype=query.dtype)
|
| 414 |
+
.matmul(v)
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """SDPA replacement for _sfdp_pattern_16 (BertLarge with dropout)."""
    counters["inductor"]["fuse_attention"] += 1
    mask = attn_mask.to(dtype=query.dtype)
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
|
| 432 |
+
# for DistilBert with dropout
|
| 433 |
+
q = query.permute([0, 2, 1, 3])
|
| 434 |
+
k = key.permute([0, 2, 1, 3])
|
| 435 |
+
v = value.permute([0, 2, 1, 3])
|
| 436 |
+
bs = q.size(0)
|
| 437 |
+
k_len = k.size(-2)
|
| 438 |
+
scores = q @ k.transpose(-2, -1)
|
| 439 |
+
scores = scores.div(inv_scale)
|
| 440 |
+
fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
|
| 441 |
+
attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
|
| 442 |
+
return (
|
| 443 |
+
torch.nn.functional.dropout(
|
| 444 |
+
torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
|
| 445 |
+
)
|
| 446 |
+
@ v
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """SDPA replacement for _sfdp_pattern_17 (DistilBert with dropout)."""
    counters["inductor"]["fuse_attention"] += 1
    batch = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    key_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    keep_mask = (
        (attn_mask == 1)
        .view((batch, 1, 1, key_len))
        .expand((batch, n_head, q_len, key_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=keep_mask.to(dtype=torch.bool),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
|
| 472 |
+
# for hf_GPT2 with dropout (introduces clone node) for inference
|
| 473 |
+
# it also returns permuted key & value
|
| 474 |
+
query = query.permute([0, 2, 1, 3])
|
| 475 |
+
key = key.permute([0, 2, 1, 3])
|
| 476 |
+
value = value.permute([0, 2, 1, 3])
|
| 477 |
+
attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
|
| 478 |
+
inv_scale = torch.full(
|
| 479 |
+
[],
|
| 480 |
+
value.size(-1) ** 0.5,
|
| 481 |
+
dtype=attn_weights.dtype,
|
| 482 |
+
device=attn_weights.device,
|
| 483 |
+
)
|
| 484 |
+
attn_weights = attn_weights.div(inv_scale)
|
| 485 |
+
causal_mask_value = torch.full(
|
| 486 |
+
(), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
|
| 487 |
+
)
|
| 488 |
+
attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
|
| 489 |
+
return (
|
| 490 |
+
(
|
| 491 |
+
torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
|
| 492 |
+
value
|
| 493 |
+
)
|
| 494 |
+
),
|
| 495 |
+
key,
|
| 496 |
+
value,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
    """SDPA replacement for _sfdp_pattern_18 (hf_GPT2).

    Mirrors the pattern's tuple return: (attention output, permuted key,
    permuted value).
    """
    counters["inductor"]["fuse_attention"] += 1
    permuted_key = key.transpose(1, 2)
    permuted_value = value.transpose(1, 2)
    attn_out = _scaled_dot_product_attention(
        query.transpose(1, 2),
        permuted_key,
        permuted_value,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / math.sqrt(value.size(-1)),
    )
    return attn_out, permuted_key, permuted_value
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
|
| 520 |
+
# for token-classification+gpt2 / text-generation+gpt2
|
| 521 |
+
attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
|
| 522 |
+
inv_scale = torch.full(
|
| 523 |
+
[],
|
| 524 |
+
value.size(-1) ** 0.5,
|
| 525 |
+
dtype=attn_weights.dtype,
|
| 526 |
+
device=attn_weights.device,
|
| 527 |
+
)
|
| 528 |
+
attn_weights = attn_weights.div(inv_scale)
|
| 529 |
+
causal_mask_value = torch.full(
|
| 530 |
+
(), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
|
| 531 |
+
)
|
| 532 |
+
attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
|
| 533 |
+
attn_weights = attn_weights + attn_mask
|
| 534 |
+
attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
|
| 535 |
+
return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
    """SDPA replacement for _sfdp_pattern_19 (gpt2 variants).

    Folds the boolean causal mask and the additive mask into a single
    additive mask (-inf where the causal mask is False) for SDPA.
    """
    counters["inductor"]["fuse_attention"] += 1
    neg_inf = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    merged_mask = torch.where(causal_mask, attn_mask, neg_inf)
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=merged_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / math.sqrt(value.size(-1)),
    )
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def _sfdp_params_check(match):
    """Extra-check callback: validate q/k/v (and any attn_mask) of a match.

    Returns False unless query/key/value share dtype and device, and — when
    the pattern contains an `aten.add.Tensor` (i.e. an additive attn_mask) —
    the mask is a Tensor on the query's device whose dtype is the query
    dtype, bool, or float.
    """
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
    key = match.kwargs["key"].meta["val"]
    value = match.kwargs["value"].meta["val"]
    if not (query.dtype == key.dtype == value.dtype):
        return False
    if not (query.device == key.device == value.device):
        return False
    add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
    # Has attn_mask add.
    if add_nodes:
        mask_node = add_nodes[0].args[1]
        # mask_node may be a float/int number rather than an fx node.
        if not hasattr(mask_node, "meta"):
            return False
        attn_mask = mask_node.meta["val"]  # type: ignore[union-attr]
        if not isinstance(attn_mask, torch.Tensor):
            return False
        # Accept attn_mask.dtype == query.dtype or torch.bool;
        # attn_mask.dtype == torch.float for models like albert.
        if attn_mask.dtype not in (query.dtype, torch.bool, torch.float):
            return False
        if query.device != attn_mask.device:
            return False
    return True
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
    """Build an extra-check callback for pattern registration.

    Optionally rejects CUDA matches (disable_cuda) and, when a scale-factor
    op is given, requires its scalar argument to be a plain float/int before
    delegating to _sfdp_params_check.
    """

    def fn(match):
        if (
            disable_cuda
            and "query" in match.kwargs
            and "cuda" in str(match.kwargs["query"].meta["val"].device)
        ):
            return False
        if scale_factor_op is not None:
            scale_node = filter_nodes(match.nodes, scale_factor_op)[0]
            # Note: args[1] of the scale node is always the scale factor
            # for the current patterns.
            scale_factor = scale_node.args[1]
            # make sure the scale_factor is a float/int. SymInt?
            if not isinstance(scale_factor, (float, int)):
                return False
        return _sfdp_params_check(match)

    return fn
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial but also updates the signature on the
    returned function: parameters bound via **kwargs are removed from the
    reported signature (downstream code inspects it).
    """
    original_sig = inspect.signature(func)
    remaining = [
        param
        for name, param in original_sig.parameters.items()
        if name not in kwargs
    ]
    new_sig = inspect.Signature(parameters=remaining)

    partial_func = functools.partial(func, **kwargs)

    def wrapper(*args, **inner_kwargs):
        return partial_func(*args, **inner_kwargs)

    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def _get_sfdp_patterns():
    """Yield (name, register_replacement kwargs) for every SDPA fusion pattern.

    Each candidate is yielded twice: as a "_training" pattern traced with
    joint_fwd_bwd, and as an "_inference" pattern traced with fwd_only (with
    any dropout_p workaround partialed away to 0.0).
    """
    from .joint_graph import patterns

    # workaround https://github.com/pytorch/pytorch/issues/97894
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g_inp = functools.partial(
        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
    )
    # attn_mask
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
        torch.empty, (1024, 128, 128), device=device, requires_grad=True
    )

    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
    g_bs1_inp = functools.partial(
        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
    )
    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)

    # softmax will generate a dtype conversion on inputs if they are in half,
    # but will not in float, so we generate a pattern for both
    for dtype in [torch.float, torch.half]:
        g = functools.partial(g_inp, dtype=dtype)
        b = functools.partial(b_inp, dtype=dtype)
        b_float = functools.partial(b_inp, dtype=torch.float)
        b_bool = functools.partial(b_inp, dtype=torch.bool)
        m = functools.partial(m_inp, dtype=dtype)
        m_float = functools.partial(m_inp, dtype=torch.float)
        m_bool = functools.partial(m_inp, dtype=torch.bool)
        c = functools.partial(c_inp, dtype=dtype)
        g_3d = functools.partial(g_3d_inp, dtype=dtype)
        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)

        # (search_fn, replace_fn, example_inputs, scalar_workaround, extra_check)
        candidates = [
            (
                _sfdp_pattern_1,
                _sfdp_replacement_1,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_2,
                _sfdp_replacement_2,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_3,
                _sfdp_replacement_3,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_4,
                _sfdp_replacement_4,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_5,
                _sfdp_replacement_5,
                [g(), g(), g(), b()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_6,
                _sfdp_replacement_6,
                [g(), g(), g(), b()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_7,
                _sfdp_replacement_7,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_8,
                _sfdp_replacement_8,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_9,
                _sfdp_replacement_9,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_10,
                _sfdp_replacement_10,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_11,
                _sfdp_replacement_11,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_12,
                _sfdp_replacement_12,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_13,
                _sfdp_replacement_13,
                [g_3d(), g_3d(), g_3d()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_14,
                _sfdp_replacement_14,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_15,
                _sfdp_replacement_15,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_17,
                _sfdp_replacement_17,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g(), g(), g(), m_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_19,
                _sfdp_replacement_19,
                [g(), g(), g(), b_bool(), b_float()],
                d,
                _sfdp_params_check,
            ),
        ]
        mask_fp32_patterns = ["pattern_16"]
        if dtype == torch.half:
            # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g(), g(), g(), m_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__

            if dtype != torch.float:
                name += "_half"
                if (
                    any(p in name for p in mask_fp32_patterns)
                    and args[3].dtype == torch.float32
                ):
                    name += "_mask_fp32"
            if args[0].size(0) == 1:
                name += "_bs1"

            training_name = name + "_training"
            yield training_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": joint_fwd_bwd,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }

            if workaround:
                assert len(workaround) == 1 and "dropout_p" in workaround
                # functools.partial insufficient because we look at signature downstream
                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                replacement = partialize_and_update_signature(
                    replacement, dropout_p=0.0
                )
                workaround = {}

            inference_name = name + "_inference"
            yield inference_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": fwd_only,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
                # with dropout turned into clone, we end up with a number of
                # semantically identical graphs
                "skip_duplicates": True,
            }
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
@functools.lru_cache(None)
def _sfdp_init():
    """Register every SDPA fusion pattern exactly once (memoized)."""
    for name, replacement_kwargs in _get_sfdp_patterns():
        gen_register_replacement(name, **replacement_kwargs)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py
ADDED
|
@@ -0,0 +1,1317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from typing import (
|
| 7 |
+
Any,
|
| 8 |
+
DefaultDict,
|
| 9 |
+
Deque,
|
| 10 |
+
Dict,
|
| 11 |
+
Iterable,
|
| 12 |
+
Iterator,
|
| 13 |
+
List,
|
| 14 |
+
Optional,
|
| 15 |
+
Set,
|
| 16 |
+
Tuple,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch._dynamo.utils import counters, optimus_scuba_log
|
| 21 |
+
from torch._utils_internal import upload_graph
|
| 22 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 23 |
+
|
| 24 |
+
from .. import config
|
| 25 |
+
from ..pattern_matcher import (
|
| 26 |
+
CallFunctionVarArgs,
|
| 27 |
+
get_arg_value,
|
| 28 |
+
stable_topological_sort,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
# importing this will register fbgemm lowerings for inductor
|
| 34 |
+
import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401
|
| 35 |
+
|
| 36 |
+
has_fbgemm = True
|
| 37 |
+
except Exception:
|
| 38 |
+
has_fbgemm = False
|
| 39 |
+
|
| 40 |
+
aten = torch.ops.aten
|
| 41 |
+
|
| 42 |
+
log = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
MIN_FUSE_SET_SIZE = 5
|
| 45 |
+
MAX_FUSE_SET_SIZE = 300
|
| 46 |
+
MAX_FUSE_SEARCH_DEPTH = 5
|
| 47 |
+
# The maximum tensor size that can go into the fusion group
|
| 48 |
+
MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096
|
| 49 |
+
# Whether we only fuse nodes with same parent node
|
| 50 |
+
FUSE_NODES_WITH_SAME_PARENT = False
|
| 51 |
+
# Whether we enable the add broadcast in batch linear
|
| 52 |
+
SHAPE_BROADCAST_BATCH_LINEAR = False
|
| 53 |
+
# Whether we enable the fuse nodes with same users
|
| 54 |
+
Fuse_NODES_WITH_SAME_USERS = False
|
| 55 |
+
|
| 56 |
+
# exclude these nodes from BFS
|
| 57 |
+
# excluding get item improves optimizer compilation time by 60s
|
| 58 |
+
SEARCH_EXCLUSIONS = {operator.getitem}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
default_graph_search_options = {
|
| 62 |
+
"min_fuse_set_size": MIN_FUSE_SET_SIZE,
|
| 63 |
+
"max_fuse_set_size": MAX_FUSE_SET_SIZE,
|
| 64 |
+
"max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
|
| 65 |
+
"max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
|
| 66 |
+
"fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT,
|
| 67 |
+
"shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR,
|
| 68 |
+
"fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
graph_search_options = default_graph_search_options
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
    """
    Update the example value of the node in the graph to enable followup
    split cat opt.

    Only torch.stack and torch.unbind are supported; any other op (or a
    node without a ``meta`` dict) is silently ignored.
    """
    if node is None or not hasattr(node, "meta"):
        return
    if op == torch.stack:
        node.meta["example_value"] = torch.stack(metadata, dim=dim)
    elif op == torch.unbind:
        node.meta["example_value"] = torch.unbind(metadata, dim=dim)  # type: ignore[assignment]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def update_pointwise_example_value(pointwise_node, input, other, op):
    """
    Update the example value of the add node in the graph to enable followup split cat opt.

    Computes ``op(input, other)`` eagerly and stores the result under
    ``pointwise_node.meta["example_value"]``. Only ``torch.add`` and
    ``torch.mul`` are supported; any other op leaves the node untouched.
    """
    # Guard: nothing to record on a missing node or one without a meta dict.
    if pointwise_node is None or not hasattr(pointwise_node, "meta"):
        return
    if op == torch.add:
        pointwise_node.meta["example_value"] = torch.add(input, other)
    elif op == torch.mul:
        pointwise_node.meta["example_value"] = torch.mul(input, other)
+
class GroupBatchFusionBase:
    """
    Common interface for group/batch fusion passes.

    Subclasses implement ``match`` (map a node to a hashable group key, or
    None if it is not a candidate) and ``fuse`` (replace a matched subset of
    nodes with a single fused op).
    """

    def __init__(self, **kwargs) -> None:
        # Callers may override the search knobs; otherwise fall back to the
        # module-level defaults. ``pop`` keeps the option out of any kwargs
        # forwarded further by subclasses.
        self.graph_search_options = kwargs.pop(
            "graph_search_options", default_graph_search_options
        )

    def match(self, node):
        """Return a group key for ``node`` or None; must be overridden."""
        raise NotImplementedError("match called on base")

    def fuse(self, graph, subset):
        """Fuse the matched ``subset`` of nodes into ``graph``; must be overridden."""
        raise NotImplementedError("fuse called on base")
+
# Registries mapping fusion name -> fusion class, populated by the
# @register_fusion decorator. Pre-grad fusions operate on torch IR before
# autograd; post-grad fusions operate on the lowered aten IR.
PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
+
def register_fusion(name: str, pre_grad=True):
    """
    Class decorator that registers a fusion class under ``name``.

    ``pre_grad`` selects which registry the class is added to
    (PRE_GRAD_FUSIONS vs. POST_GRAD_FUSIONS). The class is returned
    unchanged so the decorator is transparent.
    """

    def decorator(fusion_cls: GroupBatchFusionBase):
        registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
        registry[name] = fusion_cls
        return fusion_cls

    return decorator
+
def list_group_batch_fusions(pre_grad=True) -> List[str]:
    """Return the names of all registered pre-grad (or post-grad) fusions."""
    registry = PRE_GRAD_FUSIONS if pre_grad else POST_GRAD_FUSIONS
    return list(registry.keys())
def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
    """
    Emit the equivalent of ``torch.stack(input_tensors, dim=0)`` as
    unsqueeze-then-cat aten nodes and return the resulting cat node.

    Each new node gets a "val" meta computed eagerly from the inputs' "val"
    metas so downstream passes can reason about the fused shape.
    """
    unsqueezed_nodes = []
    unsqueezed_vals = []
    for tensor_node in input_tensors:
        unsqueezed = graph.call_function(
            aten.unsqueeze, args=(tensor_node,), kwargs={"dim": 0}
        )
        unsqueezed.meta["val"] = aten.unsqueeze(tensor_node.meta["val"], dim=0)  # type: ignore[assignment]
        unsqueezed_nodes.append(unsqueezed)
        unsqueezed_vals.append(unsqueezed.meta["val"])
    cat_node = graph.call_function(
        aten.cat, args=(unsqueezed_nodes,), kwargs={"dim": 0}
    )
    cat_node.meta["val"] = aten.cat(unsqueezed_vals, dim=0)  # type: ignore[assignment]
    return cat_node
+
class GroupFusion(GroupBatchFusionBase):
    """
    Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.

    Marker base class: "group" fusions do not require the grouped ops to
    share input shapes.
    """
+
class BatchFusion(GroupBatchFusionBase):
    """
    Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm.

    Marker base class: "batch" fusions require all grouped ops to share the
    same input shapes so they can be stacked along a new batch dimension.
    """
+
class BatchPointwiseOpsFusionFactory(BatchFusion):
    """Batch fusion parameterized by the pointwise operator being batched."""

    def __init__(self, op, **kwargs) -> None:
        super().__init__(**kwargs)
        # The concrete pointwise callable (e.g. aten.add) this instance fuses.
        self.op = op
+
@register_fusion("batch_linear_post_grad", pre_grad=False)
|
| 173 |
+
class PostGradBatchLinearFusion(BatchFusion):
|
| 174 |
+
"""
|
| 175 |
+
Fuse ops in a batch way in post grad (aten level).
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
|
| 179 |
+
# pyre-fixme[7]: Incompatible return type
|
| 180 |
+
return (
|
| 181 |
+
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def _is_input_2d(self, input: torch.fx.Node) -> bool:
|
| 185 |
+
input_shapes = input.meta["val"].shape
|
| 186 |
+
return (
|
| 187 |
+
len(input_shapes) == 2
|
| 188 |
+
and isinstance(input_shapes[0], int)
|
| 189 |
+
and isinstance(input_shapes[1], int)
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def match(
|
| 193 |
+
self, node: torch.fx.Node
|
| 194 |
+
) -> Optional[Tuple[str, int, int, int, bool, str]]:
|
| 195 |
+
if CallFunctionVarArgs(aten.mm).match(node):
|
| 196 |
+
input_m, weight_m = node.args
|
| 197 |
+
bias_m = None
|
| 198 |
+
|
| 199 |
+
elif CallFunctionVarArgs(aten.addmm.default).match(
|
| 200 |
+
node
|
| 201 |
+
) and self._addmm_node_can_be_fused(node):
|
| 202 |
+
bias_m, input_m, weight_m = node.args
|
| 203 |
+
else:
|
| 204 |
+
return None
|
| 205 |
+
# get the user of the node
|
| 206 |
+
if self.graph_search_options.get("fuse_nodes_with_same_users", False):
|
| 207 |
+
users = [user.target for user in node.users.keys()]
|
| 208 |
+
else:
|
| 209 |
+
users = "" # type: ignore[assignment]
|
| 210 |
+
# only handle the cases where inputs are 2D tensors
|
| 211 |
+
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type]
|
| 212 |
+
return None
|
| 213 |
+
m, k = input_m.meta["val"].shape # type: ignore[union-attr]
|
| 214 |
+
n = weight_m.meta["val"].shape[1] # type: ignore[union-attr]
|
| 215 |
+
batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users))
|
| 216 |
+
return batch_key
|
| 217 |
+
|
| 218 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 219 |
+
batch_inputs = []
|
| 220 |
+
batch_weights = []
|
| 221 |
+
batch_biases = []
|
| 222 |
+
batch_nodes = []
|
| 223 |
+
batch_inputs_meta = []
|
| 224 |
+
batch_weights_meta = []
|
| 225 |
+
batch_biases_meta = []
|
| 226 |
+
|
| 227 |
+
for node in subset:
|
| 228 |
+
if CallFunctionVarArgs(aten.addmm.default).match(node):
|
| 229 |
+
bias, input, weight = node.args
|
| 230 |
+
elif CallFunctionVarArgs(aten.mm.default).match(node):
|
| 231 |
+
input, weight = node.args
|
| 232 |
+
bias = None
|
| 233 |
+
batch_nodes.append(node)
|
| 234 |
+
batch_inputs.append(input) # type: ignore[possibly-undefined]
|
| 235 |
+
batch_weights.append(weight) # type: ignore[possibly-undefined]
|
| 236 |
+
batch_biases.append(bias) # type: ignore[possibly-undefined]
|
| 237 |
+
batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr]
|
| 238 |
+
batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr]
|
| 239 |
+
if bias is not None: # type: ignore[possibly-undefined]
|
| 240 |
+
batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr]
|
| 241 |
+
else:
|
| 242 |
+
batch_biases_meta.append(None)
|
| 243 |
+
|
| 244 |
+
with graph.inserting_before(subset[-1]):
|
| 245 |
+
fused_inputs = decompose_stack(graph, batch_inputs)
|
| 246 |
+
fused_weights = decompose_stack(graph, batch_weights)
|
| 247 |
+
fused_inputs_meta_val = torch.stack(
|
| 248 |
+
[input["val"] for input in batch_inputs_meta]
|
| 249 |
+
)
|
| 250 |
+
fused_weights_meta_val = torch.stack(
|
| 251 |
+
[weight["val"] for weight in batch_weights_meta]
|
| 252 |
+
)
|
| 253 |
+
fused_bmm = graph.call_function(
|
| 254 |
+
aten.bmm,
|
| 255 |
+
args=(fused_inputs, fused_weights),
|
| 256 |
+
)
|
| 257 |
+
fused_bmm.meta["val"] = aten.bmm(
|
| 258 |
+
fused_inputs_meta_val, fused_weights_meta_val
|
| 259 |
+
)
|
| 260 |
+
for i, original_mm in enumerate(batch_nodes):
|
| 261 |
+
has_bias = False
|
| 262 |
+
with graph.inserting_after(fused_bmm):
|
| 263 |
+
new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))
|
| 264 |
+
new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i)
|
| 265 |
+
if batch_biases[i]:
|
| 266 |
+
has_bias = True
|
| 267 |
+
# broadcast the bias to the same shape as the mm output
|
| 268 |
+
if self.graph_search_options.get(
|
| 269 |
+
"shape_broadcast_batch_linear", False
|
| 270 |
+
):
|
| 271 |
+
broadcast_shape = torch.broadcast_shapes(
|
| 272 |
+
batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape
|
| 273 |
+
)
|
| 274 |
+
broadcast_bias = graph.call_function(
|
| 275 |
+
aten.broadcast_to.default,
|
| 276 |
+
args=(batch_biases[i],),
|
| 277 |
+
kwargs={"size": broadcast_shape},
|
| 278 |
+
)
|
| 279 |
+
broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment]
|
| 280 |
+
new_bias_add = graph.call_function(
|
| 281 |
+
aten.add.Tensor, args=((broadcast_bias, new_mm))
|
| 282 |
+
)
|
| 283 |
+
new_bias_add.meta["val"] = aten.add.Tensor(
|
| 284 |
+
broadcast_bias.meta["val"], new_mm.meta["val"]
|
| 285 |
+
)
|
| 286 |
+
else:
|
| 287 |
+
new_bias_add = graph.call_function(
|
| 288 |
+
aten.add, args=((batch_biases[i], new_mm))
|
| 289 |
+
)
|
| 290 |
+
new_bias_add.meta["val"] = aten.add.Tensor(
|
| 291 |
+
batch_biases_meta[i]["val"], new_mm.meta["val"]
|
| 292 |
+
)
|
| 293 |
+
new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined]
|
| 294 |
+
original_mm.replace_all_uses_with(new_mm_cont)
|
| 295 |
+
new_mm_cont.meta.update(original_mm.meta)
|
| 296 |
+
graph.erase_node(original_mm)
|
| 297 |
+
counters["inductor"]["batch_linear_post_grad"] += 1
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@register_fusion("group_linear", pre_grad=False)
|
| 301 |
+
class GroupLinearFusion(GroupFusion):
|
| 302 |
+
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
|
| 303 |
+
input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr]
|
| 304 |
+
weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr]
|
| 305 |
+
return (
|
| 306 |
+
node.kwargs.get("beta", 1.0) == 1.0
|
| 307 |
+
and node.kwargs.get("alpha", 1.0) == 1.0
|
| 308 |
+
and len(input_shape) == 2
|
| 309 |
+
and len(weight_shape) == 2
|
| 310 |
+
and all(x % 2 == 0 for x in input_shape + weight_shape)
|
| 311 |
+
and all(
|
| 312 |
+
shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
|
| 313 |
+
for shape in input_shape + weight_shape
|
| 314 |
+
)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
def _mm_node_can_be_fused(self, node: torch.fx.Node):
|
| 318 |
+
input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr]
|
| 319 |
+
weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr]
|
| 320 |
+
return (
|
| 321 |
+
len(input_shape) == 2
|
| 322 |
+
and len(weight_shape) == 2
|
| 323 |
+
and all(x % 2 == 0 for x in input_shape + weight_shape)
|
| 324 |
+
and all(
|
| 325 |
+
shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
|
| 326 |
+
for shape in input_shape + weight_shape
|
| 327 |
+
)
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
|
| 331 |
+
if CallFunctionVarArgs(aten.mm.default).match(
|
| 332 |
+
node
|
| 333 |
+
) and self._mm_node_can_be_fused(node):
|
| 334 |
+
group_key = ("group_linear", True)
|
| 335 |
+
elif CallFunctionVarArgs(aten.addmm.default).match(
|
| 336 |
+
node
|
| 337 |
+
) and self._addmm_node_can_be_fused(node):
|
| 338 |
+
bias = node.args[0]
|
| 339 |
+
group_key = ("group_linear", bias is None)
|
| 340 |
+
else:
|
| 341 |
+
group_key = None
|
| 342 |
+
return group_key
|
| 343 |
+
|
| 344 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 345 |
+
group_inputs = []
|
| 346 |
+
group_weights = []
|
| 347 |
+
group_biases = []
|
| 348 |
+
group_nodes = []
|
| 349 |
+
for node in subset:
|
| 350 |
+
if CallFunctionVarArgs(aten.addmm.default).match(node):
|
| 351 |
+
bias, input, weight = node.args
|
| 352 |
+
else:
|
| 353 |
+
assert CallFunctionVarArgs(aten.mm.default).match(node)
|
| 354 |
+
input, weight = node.args
|
| 355 |
+
bias = None
|
| 356 |
+
|
| 357 |
+
group_nodes.append(node)
|
| 358 |
+
group_inputs.append(input)
|
| 359 |
+
group_weights.append(weight)
|
| 360 |
+
group_biases.append(bias)
|
| 361 |
+
|
| 362 |
+
if all(bias is None for bias in group_biases):
|
| 363 |
+
group_biases = None # type: ignore[assignment]
|
| 364 |
+
|
| 365 |
+
with graph.inserting_before(subset[0]):
|
| 366 |
+
fused_mm = graph.call_function(
|
| 367 |
+
torch.ops.fbgemm.gmm.default,
|
| 368 |
+
args=(group_inputs, group_weights, group_biases),
|
| 369 |
+
kwargs={"smart_fused": True},
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
for i, original_mm in enumerate(group_nodes):
|
| 373 |
+
with graph.inserting_after(fused_mm):
|
| 374 |
+
new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))
|
| 375 |
+
original_mm.replace_all_uses_with(new_mm)
|
| 376 |
+
new_mm.meta.update(original_mm.meta)
|
| 377 |
+
graph.erase_node(original_mm)
|
| 378 |
+
counters["inductor"]["group_linear"] += 1
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise math operator (e.g., add, mul) in post grad pass.

    Same-shaped binary pointwise ops are stacked along a new leading dim,
    applied once, then split back out with aten.select.
    """

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        self.op = op

    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
        # note: we only consider the case where the inputs are tensors
        # for mixed precision training, we need to make sure the inputs
        # of the aten.cat when do the stack should be the same dtype
        # otherwise, the output of the aten.cat may be not the same as
        # its inputs, and cause dtype not same error in mm or addmm
        input, other = node.args
        return (
            input.meta["val"].shape == other.meta["val"].shape  # type: ignore[union-attr]
            if hasattr(input, "meta")
            and hasattr(other, "meta")
            and "val" in input.meta  # type: ignore[union-attr]
            and "val" in other.meta  # type: ignore[union-attr]
            else False
        )

    def match(self, node: torch.fx.Node):
        """
        Bucket candidate pointwise nodes by op name, shape, input dtypes,
        alpha/rounding_mode kwargs, and (optionally) a shared aten.select
        parent so only siblings of the same producer fuse together.
        """
        if CallFunctionVarArgs(self.op).match(
            node
        ) and self._pointwise_node_can_be_fused(node):
            alpha = node.kwargs.get("alpha", 1.0)
            rounding_mode = node.kwargs.get("rounding_mode", None)
            input, other = node.args
            shape = list(input.meta["val"].shape)  # type: ignore[union-attr]
            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
                # only consider the linear case so far
                # pyre-fixme[16]
                if input.target == aten.select or other.target == aten.select:  # type: ignore[union-attr]
                    parent = (
                        # pyre-fixme[16]
                        input.args[0]  # type: ignore[union-attr]
                        # pyre-fixme[16]
                        if input.target == aten.select  # type: ignore[union-attr]
                        else other.args[0]  # type: ignore[union-attr]
                    )
                else:
                    parent = ""
            else:
                parent = ""
            group_key = (
                "batch_aten_" + self.op.__name__.lower().split(".")[0],
                str(shape),
                str(input.meta["val"].dtype),  # type: ignore[union-attr]
                str(other.meta["val"].dtype),  # type: ignore[union-attr]
                str(alpha),
                str(rounding_mode),
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        """
        Stack lhs and rhs operands, apply ``self.op`` once on the stacks, and
        select each original result back out. "val" metas are computed
        eagerly for the fused node.
        """
        batch_inputs, batch_others = [], []
        # All nodes in the subset share the same alpha (part of the group key).
        alpha = subset[0].kwargs.get("alpha", 1.0)
        batch_inputs_meta, batch_others_meta = [], []

        for node in subset:
            input, other = node.args
            batch_inputs.append(input)
            batch_others.append(other)
            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
            batch_others_meta.append(other.meta)  # type: ignore[possibly-undefined, union-attr]

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            stack_others = decompose_stack(graph, batch_others)
            stack_inputs_meta = torch.stack(
                [input["val"] for input in batch_inputs_meta]
            )
            stack_others_meta = torch.stack(
                [other["val"] for other in batch_others_meta]
            )

            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs, stack_others),
                # alpha is only a valid kwarg for aten.add.Tensor.
                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
            )
            batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta)
            for i, original_add in enumerate(subset):
                with graph.inserting_after(batch_op):
                    new_add = graph.call_function(
                        torch.ops.aten.select, args=((batch_op, 0, i))
                    )
                original_add.replace_all_uses_with(new_add)
                new_add.meta.update(original_add.meta)
                graph.erase_node(original_add)
        counters["inductor"][
            "batch_aten_" + self.op.__name__.lower().split(".")[0]
        ] += 1
+
@register_fusion("batch_linear_lhs")
|
| 484 |
+
class BatchLinearLHSFusion(BatchFusion):
|
| 485 |
+
"""
|
| 486 |
+
Batch linear left-hand side fusion. This pass tries to fuse the following patterns:
|
| 487 |
+
|
| 488 |
+
torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn)
|
| 489 |
+
-> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1))
|
| 490 |
+
|
| 491 |
+
We have a separate pass to eliminate contiguous transpose in a generic way.
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
|
| 495 |
+
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
| 496 |
+
node
|
| 497 |
+
) and is_linear_node_can_be_fused(node):
|
| 498 |
+
input = get_arg_value(node, 0, "input")
|
| 499 |
+
bias = get_arg_value(node, 2, "bias")
|
| 500 |
+
group_key = ("batch_linear_lhs", bias is None, input)
|
| 501 |
+
else:
|
| 502 |
+
group_key = None
|
| 503 |
+
return group_key
|
| 504 |
+
|
| 505 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 506 |
+
batch_nodes = []
|
| 507 |
+
batch_input = None
|
| 508 |
+
batch_weights, batch_weights_meta = [], []
|
| 509 |
+
batch_biases, batch_biases_meta = [], []
|
| 510 |
+
split_sections = []
|
| 511 |
+
for node in subset:
|
| 512 |
+
input = get_arg_value(node, 0, "input")
|
| 513 |
+
weight = get_arg_value(node, 1, "weight")
|
| 514 |
+
bias = get_arg_value(node, 2, "bias")
|
| 515 |
+
batch_nodes.append(node)
|
| 516 |
+
if batch_input is None:
|
| 517 |
+
batch_input = input
|
| 518 |
+
else:
|
| 519 |
+
assert batch_input is input
|
| 520 |
+
batch_weights.append(weight)
|
| 521 |
+
batch_weights_meta.append(weight.meta["example_value"])
|
| 522 |
+
if bias:
|
| 523 |
+
batch_biases.append(bias)
|
| 524 |
+
batch_biases_meta.append(bias.meta["example_value"])
|
| 525 |
+
split_sections.append(weight.meta["example_value"].shape[0])
|
| 526 |
+
|
| 527 |
+
with graph.inserting_before(subset[0]):
|
| 528 |
+
cat_weights = graph.call_function(
|
| 529 |
+
torch.cat, args=(batch_weights,), kwargs={"dim": 0}
|
| 530 |
+
)
|
| 531 |
+
cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0)
|
| 532 |
+
transposed_weights = graph.call_function(
|
| 533 |
+
torch.transpose, args=(cat_weights, 0, 1)
|
| 534 |
+
)
|
| 535 |
+
transposed_weights.meta["example_value"] = torch.transpose(
|
| 536 |
+
cat_weights.meta["example_value"], 0, 1
|
| 537 |
+
)
|
| 538 |
+
if len(batch_biases) > 0:
|
| 539 |
+
cat_biases = graph.call_function(
|
| 540 |
+
torch.cat, args=(batch_biases,), kwargs={"dim": 0}
|
| 541 |
+
)
|
| 542 |
+
cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0)
|
| 543 |
+
fused_lhs = graph.call_function(
|
| 544 |
+
torch.addmm,
|
| 545 |
+
args=(cat_biases, batch_input, transposed_weights),
|
| 546 |
+
)
|
| 547 |
+
fused_lhs.meta["example_value"] = torch.addmm(
|
| 548 |
+
cat_biases.meta["example_value"],
|
| 549 |
+
batch_input.meta["example_value"], # type: ignore[union-attr]
|
| 550 |
+
transposed_weights.meta["example_value"],
|
| 551 |
+
)
|
| 552 |
+
else:
|
| 553 |
+
fused_lhs = graph.call_function(
|
| 554 |
+
torch.mm,
|
| 555 |
+
args=(batch_input, transposed_weights),
|
| 556 |
+
)
|
| 557 |
+
fused_lhs.meta["example_value"] = torch.mm(
|
| 558 |
+
batch_input.meta["example_value"], # type: ignore[union-attr]
|
| 559 |
+
transposed_weights.meta["example_value"],
|
| 560 |
+
)
|
| 561 |
+
fused_lhs_list = graph.call_function(
|
| 562 |
+
torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
for i, node in enumerate(batch_nodes):
|
| 566 |
+
with graph.inserting_after(fused_lhs_list):
|
| 567 |
+
new_node = graph.call_function(
|
| 568 |
+
operator.getitem, args=(fused_lhs_list, i)
|
| 569 |
+
)
|
| 570 |
+
node.replace_all_uses_with(new_node)
|
| 571 |
+
new_node.meta.update(node.meta)
|
| 572 |
+
graph.erase_node(node)
|
| 573 |
+
counters["inductor"]["batch_linear_lhs"] += 1
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def is_node_meta_valid(node: Optional[torch.fx.Node]):
    """
    True when ``node`` carries usable metadata.

    A None node (missing optional argument such as bias) counts as valid;
    otherwise the node must expose either an "example_value" (pre-grad) or
    a "val" (post-grad) entry in its meta dict.
    """
    if node is None:
        return True
    return "example_value" in node.meta or "val" in node.meta
|
| 580 |
+
# Poor person's check for if a node in the graph mutates its input.
|
| 581 |
+
# (the graph is torch IR, so we will see torch fns and python operators)
|
| 582 |
+
def _is_mutable_node(tgt):
|
| 583 |
+
if str(tgt).endswith("_"):
|
| 584 |
+
# e.g. torch.mul_, torch.Tensor.mul_
|
| 585 |
+
return True
|
| 586 |
+
if (
|
| 587 |
+
hasattr(tgt, "__module__")
|
| 588 |
+
and tgt.__module__ == "_operator"
|
| 589 |
+
and tgt.__name__.startswith("i")
|
| 590 |
+
):
|
| 591 |
+
# e.g. operator.iand, operator.imul
|
| 592 |
+
return True
|
| 593 |
+
return False
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def is_linear_node_can_be_fused(node: torch.fx.Node):
    """
    Decide whether an ``F.linear`` call node is a safe batching candidate:
    valid metadata on node/input/weight, both 2D, and no in-place users.
    """
    input = get_arg_value(node, 0, "input")
    weight = get_arg_value(node, 1, "weight")
    if not (
        is_node_meta_valid(node)
        and is_node_meta_valid(input)
        and is_node_meta_valid(weight)
    ):
        return False
    if len(input.meta["example_value"].shape) != 2:
        return False
    if len(weight.meta["example_value"].shape) != 2:
        return False
    # the mm -> bmm transform adds an unbind() op,
    # which is not safe for autograd when the output of the mm is mutated.
    # don't pattern match if any users of the mm mutate the input.
    return not any(_is_mutable_node(user.target) for user in node.users)
+
@register_fusion("batch_linear")
|
| 613 |
+
class PreGradBatchLinearFusion(BatchFusion):
|
| 614 |
+
"""
|
| 615 |
+
Batch linear fusion in pre grad pass.
|
| 616 |
+
Fuse linear with same size with torch.baddmm
|
| 617 |
+
"""
|
| 618 |
+
|
| 619 |
+
def _getitem_args(self, getitem_node: torch.fx.Node):
|
| 620 |
+
if getitem_node.target != operator.__getitem__ or (
|
| 621 |
+
getitem_node.op != "call_function"
|
| 622 |
+
):
|
| 623 |
+
return None
|
| 624 |
+
return getitem_node.args[0]
|
| 625 |
+
|
| 626 |
+
def match(self, node: torch.fx.Node):
|
| 627 |
+
if CallFunctionVarArgs(torch.nn.functional.linear).match(
|
| 628 |
+
node
|
| 629 |
+
) and is_linear_node_can_be_fused(node):
|
| 630 |
+
input = get_arg_value(node, 0, "input")
|
| 631 |
+
weight = get_arg_value(node, 1, "weight")
|
| 632 |
+
bias = get_arg_value(node, 2, "bias")
|
| 633 |
+
if self.graph_search_options.get("fuse_nodes_with_same_users", False):
|
| 634 |
+
users = [user.target for user in node.users.keys()]
|
| 635 |
+
else:
|
| 636 |
+
users = "" # type: ignore[assignment]
|
| 637 |
+
group_key = (
|
| 638 |
+
"batch_linear",
|
| 639 |
+
self._getitem_args(input),
|
| 640 |
+
str(input.meta["example_value"].shape),
|
| 641 |
+
str(weight.meta["example_value"].shape),
|
| 642 |
+
bias is None,
|
| 643 |
+
str(users),
|
| 644 |
+
)
|
| 645 |
+
else:
|
| 646 |
+
group_key = None
|
| 647 |
+
return group_key
|
| 648 |
+
|
| 649 |
+
def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
|
| 650 |
+
batch_nodes = []
|
| 651 |
+
batch_inputs = []
|
| 652 |
+
batch_weights = []
|
| 653 |
+
batch_biases = []
|
| 654 |
+
batch_inputs_metadata = []
|
| 655 |
+
batch_weights_metadata = []
|
| 656 |
+
batch_biases_metadata = []
|
| 657 |
+
for node in subset:
|
| 658 |
+
batch_nodes.append(node)
|
| 659 |
+
input = get_arg_value(node, 0, "input")
|
| 660 |
+
batch_inputs.append(input)
|
| 661 |
+
batch_inputs_metadata.append(input.meta["example_value"])
|
| 662 |
+
weight = get_arg_value(node, 1, "weight")
|
| 663 |
+
batch_weights.append(weight)
|
| 664 |
+
batch_weights_metadata.append(weight.meta["example_value"])
|
| 665 |
+
bias = get_arg_value(node, 2, "bias")
|
| 666 |
+
batch_biases.append(bias)
|
| 667 |
+
if bias is not None and hasattr(bias, "meta"):
|
| 668 |
+
batch_biases_metadata.append(bias.meta["example_value"])
|
| 669 |
+
|
| 670 |
+
with graph.inserting_before(subset[0]):
|
| 671 |
+
stack_inputs = graph.call_function(
|
| 672 |
+
torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
|
| 673 |
+
)
|
| 674 |
+
update_stack_example_value(stack_inputs, batch_inputs_metadata)
|
| 675 |
+
stack_weights = graph.call_function(
|
| 676 |
+
torch.stack, args=(batch_weights,), kwargs={"dim": 0}
|
| 677 |
+
)
|
| 678 |
+
update_stack_example_value(stack_weights, batch_weights_metadata)
|
| 679 |
+
transpose_weight = graph.call_function(
|
| 680 |
+
torch.transpose, args=(stack_weights, 1, 2)
|
| 681 |
+
)
|
| 682 |
+
transpose_weight.meta["example_value"] = torch.transpose(
|
| 683 |
+
stack_weights.meta["example_value"], 1, 2
|
| 684 |
+
)
|
| 685 |
+
if all(bias is None for bias in batch_biases):
|
| 686 |
+
bmm = graph.call_function(
|
| 687 |
+
torch.bmm,
|
| 688 |
+
args=(stack_inputs, transpose_weight),
|
| 689 |
+
)
|
| 690 |
+
bmm.meta["example_value"] = torch.bmm(
|
| 691 |
+
stack_inputs.meta["example_value"],
|
| 692 |
+
transpose_weight.meta["example_value"],
|
| 693 |
+
)
|
| 694 |
+
bmm_meta = bmm.meta["example_value"]
|
| 695 |
+
else:
|
| 696 |
+
stack_biases = graph.call_function(
|
| 697 |
+
torch.stack, args=(batch_biases,), kwargs={"dim": 0}
|
| 698 |
+
)
|
| 699 |
+
update_stack_example_value(stack_biases, batch_biases_metadata)
|
| 700 |
+
unsqueeze_biases = graph.call_function(
|
| 701 |
+
torch.unsqueeze, args=(stack_biases, 1)
|
| 702 |
+
)
|
| 703 |
+
unsqueeze_biases.meta["example_value"] = torch.unsqueeze(
|
| 704 |
+
stack_biases.meta["example_value"], 1
|
| 705 |
+
)
|
| 706 |
+
bmm = graph.call_function(
|
| 707 |
+
torch.baddbmm,
|
| 708 |
+
args=(unsqueeze_biases, stack_inputs, transpose_weight),
|
| 709 |
+
)
|
| 710 |
+
try:
|
| 711 |
+
# it will have runtime error to broadcast when it has dynamic shape included
|
| 712 |
+
# in the meta data, so we need to skip the update meta data
|
| 713 |
+
bmm.meta["example_value"] = torch.baddbmm(
|
| 714 |
+
unsqueeze_biases.meta["example_value"],
|
| 715 |
+
stack_inputs.meta["example_value"],
|
| 716 |
+
transpose_weight.meta["example_value"],
|
| 717 |
+
)
|
| 718 |
+
bmm_meta = bmm.meta["example_value"]
|
| 719 |
+
except Exception as e:
|
| 720 |
+
log.debug(
|
| 721 |
+
f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004
|
| 722 |
+
)
|
| 723 |
+
bmm_meta = None
|
| 724 |
+
|
| 725 |
+
bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
|
| 726 |
+
if bmm_meta is not None:
|
| 727 |
+
bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
|
| 728 |
+
for i, linear in enumerate(batch_nodes):
|
| 729 |
+
with graph.inserting_after(bmm):
|
| 730 |
+
getitem = graph.call_function(operator.getitem, args=(bmm, i))
|
| 731 |
+
linear.replace_all_uses_with(getitem)
|
| 732 |
+
getitem.meta.update(linear.meta)
|
| 733 |
+
graph.erase_node(linear)
|
| 734 |
+
counters["inductor"]["batch_linear"] += 1
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
@register_fusion("batch_layernorm")
class BatchLayernormFusion(BatchFusion):
    """
    Batch layer norm fusion in pre grad pass.

    Groups multiple ``torch.nn.functional.layer_norm`` calls that share the
    same input/weight/bias shapes, normalized_shape and eps into a single
    layer_norm over a stacked input, then unbinds the result back to the
    original consumers.
    """

    def match(self, node: torch.fx.Node):
        # Returns a hashable group key for fusable layer_norm nodes, or None.
        # Nodes with equal keys are candidates to be batched together.
        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 2, "weight")
            bias = get_arg_value(node, 3, "bias")
            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
                # Optionally require the fused nodes to feed the same user ops.
                users = [user.target for user in node.users.keys()]
            else:
                users = ""  # type: ignore[assignment]
            # Key is only built when example metadata is available for
            # input/weight/bias; otherwise shapes can't be compared.
            group_key = (
                (
                    "batch_layernorm",
                    str(input.meta["example_value"].shape),
                    str(weight.meta["example_value"].shape)
                    if weight is not None
                    else "",
                    str(bias.meta["example_value"].shape) if bias is not None else "",
                    str(get_arg_value(node, 1, "normalized_shape")),
                    str(get_arg_value(node, 4, "eps")),
                    str(users),
                )
                if "example_value" in input.meta
                and is_node_meta_valid(weight)
                and is_node_meta_valid(bias)
                else None
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        # Replaces each layer_norm in `subset` with a slice of one batched
        # layer_norm. Weight/bias are applied manually (mul/add) after the
        # un-affine batched layer_norm so per-node parameters can be stacked.
        group_inputs = []
        group_shapes = []
        group_weights = []
        group_biases = []
        group_epss = []
        group_nodes = []
        group_inputs_metadata = []
        group_biases_metadata = []
        group_weights_metadata = []
        for node in subset:
            group_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            group_inputs.append(input)
            group_inputs_metadata.append(input.meta["example_value"])
            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
            weight = get_arg_value(node, 2, "weight")
            group_weights.append(weight)
            if weight is not None and hasattr(weight, "meta"):
                group_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 3, "bias")
            group_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                group_biases_metadata.append(bias.meta["example_value"])
            eps = get_arg_value(node, 4, "eps")
            if eps is None:
                # layer_norm's documented default eps.
                eps = 1e-5
            group_epss.append(eps)
        # Stack just before the normalized dims so the batched layer_norm
        # normalizes the same trailing dimensions as the originals.
        stack_dim = -1 - len(group_shapes[-1])

        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        if all(weight is None for weight in group_weights):
            group_weights = None  # type: ignore[assignment]
        assert all(
            eps == group_epss[0] for eps in group_epss
        ), "all epsilon values must be equal"

        with graph.inserting_before(subset[0]):
            stack_input = graph.call_function(
                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
            )
            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
            if group_weights is not None:
                stack_weight = graph.call_function(
                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_weight, group_weights_metadata)
            else:
                stack_weight = None
            if group_biases is not None:
                stack_bias = graph.call_function(
                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_bias, group_biases_metadata)
            else:
                stack_bias = None

            # Batched layer_norm without weight/bias; affine transform is
            # applied below via explicit mul/add on the stacked parameters.
            batch_layer_norm = graph.call_function(
                torch.nn.functional.layer_norm,
                args=(stack_input, group_shapes[-1]),
                kwargs={"eps": group_epss[-1]},
            )
            # NOTE(review): layer_norm preserves the input shape, so reusing
            # the stacked input's example_value here looks intentional.
            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]

            if group_weights is not None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )
            elif group_weights is not None and group_biases is None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
            elif group_weights is None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )

            # Split the batched result back into one tensor per original node.
            batch_layer_norm_unbind = graph.call_function(
                torch.unbind,
                args=(batch_layer_norm,),
                kwargs={"dim": stack_dim},
            )
            update_stack_example_value(
                batch_layer_norm_unbind,
                batch_layer_norm.meta["example_value"],
                op=torch.unbind,
                dim=stack_dim,
            )

        for i, node in enumerate(group_nodes):
            with graph.inserting_after(batch_layer_norm_unbind):
                new_node = graph.call_function(
                    operator.getitem, args=(batch_layer_norm_unbind, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
        counters["inductor"]["batch_layernorm"] += 1
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass.
    We fuse it in random place, and the introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        # The pointwise callable to batch (e.g. torch.tanh).
        self.op = op

    def match(self, node: torch.fx.Node):
        # Returns a hashable group key for fusable calls to self.op, or None.
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
                # pyre-fixme[16]
                parent = node.args[0]
                parent = parent.target if parent is not None else ""  # type: ignore[union-attr]
            else:
                parent = ""
            # for relu op, we also use the inplace to construct the key
            group_key = (
                "batch_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["example_value"].shape),
                str(node.kwargs.get("inplace", False)),
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        # Stack all inputs, apply self.op once, then unbind and rewire each
        # original node's users to its slice of the batched result.
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            if self.op == torch.nn.functional.relu:
                # relu is the only batched op here taking an `inplace` kwarg;
                # propagate it from the first node (all keys matched on it).
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"],
                    inplace=subset[0].kwargs.get("inplace", False),
                )
            else:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"]
                )
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
            unbind_op.meta["example_value"] = torch.unbind(
                batch_op.meta["example_value"], dim=0
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(unbind_op):
                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
                    node.replace_all_uses_with(getitem)
                    getitem.meta.update(node.meta)
                    graph.erase_node(node)
        counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass.
    The introduced stack node may be merged in split cat.
    """

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        # The aten overload to batch (e.g. aten.tanh.default).
        self.op = op

    def match(self, node: torch.fx.Node):
        # Returns a hashable group key for fusable aten calls, or None.
        # Post-grad metadata lives under meta["val"] (not "example_value").
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            # for relu op, we also use the inplace to construct the key
            # we batch the ops with same parent to enable followup split cat
            parent = node.args[0]
            parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else ""  # type: ignore[union-attr]
            group_key = (
                "batch_aten_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["val"].shape),
                str(node.kwargs.get("inplace", False)),
                # pyre-fixme[16]
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        # Stack inputs (via decompose_stack, which may avoid a real stack),
        # apply the aten op once, then select each row back out.
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["val"])

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs,),
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(batch_op):
                    # aten.select(batch, 0, i) extracts the i-th row.
                    getitem = graph.call_function(aten.select, args=(batch_op, 0, i))
                    node.replace_all_uses_with(getitem)
                    getitem.meta.update(node.meta)
                    graph.erase_node(node)
        counters["inductor"][
            "batch_aten_" + self.op.__name__.lower().split(".")[0]
        ] += 1
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
@register_fusion("batch_tanh")
class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
    # Pre-grad batching of torch.tanh calls; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.tanh, **kwargs)
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
@register_fusion("batch_sigmoid")
class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
    # Pre-grad batching of torch.sigmoid calls; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.sigmoid, **kwargs)
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
@register_fusion("batch_relu")
class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
    # Pre-grad batching of F.relu calls; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.nn.functional.relu, **kwargs)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
@register_fusion("batch_aten_tanh", pre_grad=False)
class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion):
    # Post-grad batching of aten.tanh; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.tanh.default, **kwargs)
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
@register_fusion("batch_aten_sigmoid", pre_grad=False)
class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion):
    # Post-grad batching of aten.sigmoid; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sigmoid.default, **kwargs)
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
@register_fusion("batch_aten_relu", pre_grad=False)
class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion):
    # Post-grad batching of aten.relu; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.relu.default, **kwargs)
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    # Post-grad batching of aten.add.Tensor; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.add.Tensor, **kwargs)
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    # Post-grad batching of aten.sub.Tensor; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sub.Tensor, **kwargs)
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    # Post-grad batching of aten.div.Tensor; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.div.Tensor, **kwargs)
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    # Post-grad batching of aten.mul.Tensor; all logic lives in the base class.
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.mul.Tensor, **kwargs)
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
class _OrderedSet:
|
| 1102 |
+
def __init__(self, param=None) -> None:
|
| 1103 |
+
if param:
|
| 1104 |
+
self.rep = OrderedDict(dict.fromkeys(param))
|
| 1105 |
+
else:
|
| 1106 |
+
self.rep = OrderedDict()
|
| 1107 |
+
|
| 1108 |
+
def __contains__(self, o) -> bool:
|
| 1109 |
+
return o in self.rep
|
| 1110 |
+
|
| 1111 |
+
def __len__(self) -> int:
|
| 1112 |
+
return self.rep.__len__()
|
| 1113 |
+
|
| 1114 |
+
def append(self, o):
|
| 1115 |
+
self.rep[o] = None
|
| 1116 |
+
|
| 1117 |
+
def __iter__(self):
|
| 1118 |
+
return self.rep.keys().__iter__()
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
def find_independent_subset_greedy(
    node_list: Iterable[torch.fx.Node],
    graph_search_options: Dict[str, Any],
) -> Iterator[Iterable[torch.fx.Node]]:
    """
    Yields a list of subsets of `node_list` where no element in the subset
    depends on any other element in the subset. This results in a set of
    independent nodes which can be fused together.

    The order of `node_list` is preserved within each subset so we can benefit
    from split-cat elimination in later passes.

    During iteration it is only safe to mutate the graph by changing the nodes
    that have been returned.

    graph_search_options:
        - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
          this size will be ignored.
        - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
          be broken to be at most this size.
    """

    # Compute all the children of `node` which are members of
    # `interesting_nodes`.
    def find_dependent_nodes(node, interesting_nodes):
        # Iterative DFS over input edges; `visited_node_set` guards cycles
        # and re-visits, `dep_set` collects hits inside interesting_nodes.
        visited_node_set: Set[torch.fx.Node] = {node}
        dep_set: Set[torch.fx.Node] = set()

        work = [node]
        while work:
            node = work.pop()
            for input_node in node.all_input_nodes:
                if input_node in interesting_nodes:
                    dep_set.add(input_node)

                if input_node not in visited_node_set:
                    visited_node_set.add(input_node)
                    work.append(input_node)

        return dep_set

    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
    max_fuse_set_size = graph_search_options["max_fuse_set_size"]

    # node_list needs to be a set because we only track the nodes that are left
    # in it (and we want to do the `in` on a set, not a list). But we want to
    # keep the correct order.
    node_list = _OrderedSet(node_list)

    cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
    while node_list:
        subset: List[torch.fx.Node] = []
        subset_deps: Set[torch.fx.Node] = set()

        # Nodes that couldn't join the current subset are retried next round.
        next_round_node_list = _OrderedSet()
        for node in node_list:
            if len(subset) >= max_fuse_set_size or node in subset_deps:
                next_round_node_list.append(node)
                continue

            dep_set = cache.pop(node, None)
            if dep_set is None:
                dep_set = find_dependent_nodes(node, node_list)

            if not dep_set.intersection(subset):
                subset.append(node)
                subset_deps.update(dep_set)
            else:
                next_round_node_list.append(node)
                # Cache the computed deps for the retry in a later round.
                cache[node] = dep_set

        if len(subset) >= min_fuse_set_size:
            # Careful here - the caller uses the subsets to fuse nodes together
            # so we need to clear any cache entry that contains one of the
            # returned nodes because the dependency list could be different
            # (larger) after the merge.
            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
            yield subset

        node_list = next_round_node_list
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
def get_fusion_candidates(
    rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
    """
    Search fusion candidates for a specific rule using BFS starting from the root node.
    We only search the subgraph within graph_search_options["max_fuse_search_depth"].

    Returns a mapping from the rule's group key to the matched nodes, in BFS
    discovery order. Nodes already in `fused_set` are skipped entirely
    (neither matched nor expanded).
    """
    q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()

    candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
        list
    )

    # Some targets are excluded from the search entirely (project-defined set).
    if root_node.target in SEARCH_EXCLUSIONS:
        return candidate_dict

    visited_set: Set[torch.fx.Node] = set()

    # Seed the BFS with the root's direct inputs at depth 1.
    for next_node in root_node.all_input_nodes:
        q.append((1, next_node))
        visited_set.add(next_node)

    while len(q) > 0:
        depth, node = q.popleft()

        if node in fused_set:
            continue

        key = rule.match(node)
        if key is not None:
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
        else:
            # Only expand past non-matching nodes, and only within the
            # configured search depth.
            if depth < rule.graph_search_options["max_fuse_search_depth"]:
                for next_node in node.all_input_nodes:
                    if next_node not in visited_set:
                        visited_set.add(next_node)
                        q.append((depth + 1, next_node))

    return candidate_dict
|
| 1244 |
+
|
| 1245 |
+
|
| 1246 |
+
def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
    """Apply one batch-fusion rule across the whole graph.

    Walks the graph's nodes in reverse, collects fusion candidates per group
    key, fuses each independent subset via ``rule.fuse``, and uploads the
    transformed graph for logging when any fusion happened.
    """
    stable_topological_sort(graph)  # type: ignore[arg-type]
    fused_set: Set[torch.fx.Node] = set()
    log_to_scuba = False

    for node in reversed(graph.nodes):
        candidates = get_fusion_candidates(rule, node, fused_set)

        for key, candidate_nodes in candidates.items():
            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
                continue

            for subset in find_independent_subset_greedy(
                candidate_nodes, rule.graph_search_options
            ):
                rule.fuse(graph, subset)
                fused_set.update(subset)
                # Lazy %-style args: the message is only formatted when DEBUG
                # logging is enabled (the original f-string paid the cost
                # unconditionally; `subset` is already a list, no copy needed).
                log.debug(
                    "%s: key = %s; subset size = %s",
                    rule.__class__.__name__,
                    key,
                    len(subset),
                )
                log_to_scuba = True
    if log_to_scuba:
        # Only upload when at least one fusion fired.
        optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph)
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
    """Instantiate fusion rule objects from a config dict.

    Each entry maps a registered fusion name to its per-fusion option
    overrides, which are layered on top of the default graph_search_options.
    """
    fusions: List[GroupBatchFusionBase] = []
    for name, options in config_options.items():
        # we skip all patterns from pattern_matcher passes (e.g., split_cat)
        if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS:
            continue
        # NOTE(review): if `name` is registered only for the *other* phase
        # (e.g. a post-grad-only fusion with pre_grad=True), this lookup
        # raises KeyError. Presumably callers keep the config dicts
        # phase-consistent — verify before relying on it.
        fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name]
        _options = graph_search_options.copy()
        _options.update(options)
        fusions.append(fusion_cls(graph_search_options=_options))  # type: ignore[operator]
    return fusions
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    """Build the configured fusion rules and apply them to `graph`.

    Pre-grad rules come from config.pre_grad_fusion_options; post-grad rules
    are split into fbgemm-requiring and regular ones, and fbgemm rules only
    run when fbgemm is importable.
    """
    fusions: List[GroupBatchFusionBase] = []
    # we keep all current pre grad fusions to keep
    # current implementation, will remove this later
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
    else:
        # Partition post-grad fusions by whether they require fbgemm kernels.
        fbgemm_fusion_keys = [
            x
            for x in config.post_grad_fusion_options
            if config.post_grad_fusion_options[x].get("require_fbgemm", False)
        ]
        fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in fbgemm_fusion_keys
        }
        non_fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in config.post_grad_fusion_options.keys()
            if fusion not in fbgemm_fusion_keys
        }
        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
        if has_fbgemm:
            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for i, rule in enumerate(fusions):
        # Each rule application is wrapped in an observer for graph-transform
        # logging/tracing.
        with GraphTransformObserver(
            graph.owning_module,
            f"group_batch_fusion_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py
ADDED
|
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
import typing
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from typing import Any, Dict, List, Set, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._guards
|
| 10 |
+
import torch.utils._pytree as pytree
|
| 11 |
+
from torch._inductor.constant_folding import ConstantFolder
|
| 12 |
+
from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
|
| 13 |
+
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
| 14 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 15 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 16 |
+
|
| 17 |
+
from ...utils._ordered_set import OrderedSet
|
| 18 |
+
from .. import config
|
| 19 |
+
from ..pattern_matcher import (
|
| 20 |
+
CallFunction,
|
| 21 |
+
init_once_fakemode,
|
| 22 |
+
KeywordArg,
|
| 23 |
+
Match,
|
| 24 |
+
MULTIPLE,
|
| 25 |
+
PatternMatcherPass,
|
| 26 |
+
register_graph_pattern,
|
| 27 |
+
stable_topological_sort,
|
| 28 |
+
)
|
| 29 |
+
from .replace_random import replace_random_passes
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
log = logging.getLogger(__name__)
# Default pattern-matcher pass; patterns registered against it run first.
patterns = PatternMatcherPass()
aten = torch.ops.aten
prims = torch.ops.prims

# Ordered list of pattern passes; index selects which pass a pattern joins.
pass_patterns = [
    patterns,
    PatternMatcherPass(),
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@init_once_fakemode
def lazy_init():
    """Register the pattern libraries exactly once, under fake-tensor mode.

    Imports are deferred to avoid paying their cost (and fake-mode tracing)
    at module import time.
    """
    from .fuse_attention import _sfdp_init
    from .misc_patterns import _misc_patterns_init
    from .pad_mm import _pad_mm_init

    _pad_mm_init()
    _sfdp_init()
    _misc_patterns_init()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def remove_no_ops(
    gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
):
    """Remove arithmetic no-ops (+0, -0, *1, /1) and meta-tensor outputs.

    `zeros`/`ones` are the graph nodes known to be all-zero / all-one
    constants. Replacements only happen when the replacing input's fake
    tensor matches the node's (inserting a dtype cast when only the dtype
    differs).
    """
    with torch.utils._python_dispatch._disable_current_modes():
        "Removes no-ops: (+ 0, - 0, * 1, / 1)"
        graph = gm.graph

        def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
            # Compare fake-tensor metadata on the given fields only.
            if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
                return False
            for field in fields:
                if getattr(t1, field) != getattr(t2, field):
                    return False
            return True

        def replace_no_op(node, replace_input_index):
            # Replace `node` with its input at `replace_input_index`,
            # inserting a convert_element_type when only the dtype differs.
            replacement = node.args[replace_input_index]

            # https://github.com/pytorch/pytorch/issues/86128 causes
            # non-Tensor inputs even for ops with only Tensor inputs.
            # TODO - decompose/type promote to avoid this
            if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
                return

            if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
                if fake_tensors_eq(
                    node.meta["val"],
                    replacement.meta["val"],
                    ("shape", "device"),
                ):
                    # Same shape/device, different dtype: cast instead of bail.
                    with graph.inserting_after(node):
                        replacement = graph.call_function(
                            torch.ops.prims.convert_element_type.default,
                            args=(replacement, node.meta["val"].dtype),
                        )
                else:
                    # Shape or device differs (e.g. broadcasting) — not a no-op.
                    return

            node.replace_all_uses_with(replacement)
            replacement.meta.update(node.meta)
            graph.erase_node(node)

        for node in graph.find_nodes(op="call_function", target=aten.add.Tensor):
            # TODO handle Tensor-Scalar adds, it's a different schema
            if len(node.args) == 2:
                # alpha != 1 scales the addend, so x + 0*alpha is still x, but
                # 0 + x*alpha is not x; keep the conservative alpha==1 check.
                if (
                    not any(e in zeros for e in node.args)
                    or node.kwargs.get("alpha", 1) != 1
                ):
                    continue

                replace_index = 1 if node.args[0] in zeros else 0
                replace_no_op(node, replace_index)

        for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor):
            if len(node.args) == 2:
                # Only x - 0 is a no-op (0 - x negates).
                if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
                    continue

                replace_no_op(node, 0)

        for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor):
            if len(node.args) == 2:
                if not any(e in ones for e in node.args):
                    continue

                replace_input_index = 1 if node.args[0] in ones else 0
                replace_no_op(node, replace_input_index)

        for node in graph.find_nodes(op="call_function", target=aten.div.Tensor):
            # Only x / 1 is a no-op (1 / x is a reciprocal).
            if len(node.args) == 2 and node.args[1] in ones:
                replace_no_op(node, 0)

        # meta tensors returned from the graph have no data and can be replaced with empty_strided
        for output_node in graph.find_nodes(op="output"):
            had_meta_return = False

            def visit(n):
                nonlocal had_meta_return
                val = n.meta.get("val")
                if isinstance(val, torch.Tensor) and val.device.type == "meta":
                    with graph.inserting_before(output_node):
                        n.replace_all_uses_with(
                            graph.call_function(
                                torch.ops.aten.empty_strided.default,
                                args=(val.size(), val.stride()),
                                kwargs={"dtype": val.dtype, "device": val.device},
                            )
                        )
                    had_meta_return = True

            torch.fx.map_arg(output_node.args, visit)
            if had_meta_return:
                # The replaced returns may leave their producers dead.
                graph.eliminate_dead_code()
|
| 149 |
+
|
| 150 |
+
def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant views by reusing existing ones.

    Deduplicates ``aten.view.dtype`` nodes: all dtype-views of the same
    source tensor are tracked in a shared per-alias-set dict, and a second
    view to an already-seen dtype is replaced by the earlier node.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        # A dictionary mapping a tensor to all aliased views.
        # Every node in one alias set shares the *same* inner dict, keyed by dtype.
        views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
        graph = gm.graph

        for node in graph.find_nodes(
            op="call_function", target=torch.ops.aten.view.dtype
        ):
            src = node.args[0]
            to_type = node.args[1]
            existing_views = views.get(src)
            is_needed = True

            if existing_views:
                # Replace the view with an existing view if available.
                alias = existing_views.get(to_type)
                if alias:
                    is_needed = False
                    node.replace_all_uses_with(alias)
                    alias.meta.update(node.meta)
                    graph.erase_node(node)
            else:
                # First view of this source: seed the alias set with the
                # source itself under its original dtype.
                from_type = src.meta["val"].dtype
                existing_views = {from_type: src}
                views[src] = existing_views

            if is_needed:
                # Save the new alias but do not replace existing one.
                existing_views.setdefault(to_type, node)
                views[node] = existing_views

        # Clean up unused views. Erasing one view can make its input view
        # unused, so iterate to a fixed point.
        while True:
            unused_views = [alias for alias in views if not alias.users]
            if len(unused_views) == 0:
                break
            for unused in unused_views:
                views.pop(unused)
                graph.erase_node(unused)
| 194 |
+
|
| 195 |
+
class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
    with a tensor constructor call: aten.full([shape], value, ...)
    """

    def __init__(self, gm, skip_constructors=False) -> None:
        super().__init__(gm, skip_constructors)
        self.node_storages_ptrs: Dict[torch.fx.Node, int] = {}
        # Weak refs to each replaced node's storage, used later to detect aliasing.
        self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {}
        # we may constant fold a tensor which in the graph has a sym size
        # see: [constant folding refining of symints]
        self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

        # initialize symint -> node mapping so that we can
        # use symint nodes in full constructors
        self.symint_nodes = _SymHashingDict()
        for n in self.module.graph.nodes:
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                self.symint_nodes[n.meta["val"]] = n

        # reference from torch/_funtorch/partitioners.py:get_default_op_list
        # Ops that merely alias/reshape their input — folding passes the
        # input's substitute value straight through.
        self.view_op_packets = [
            aten.squeeze,
            aten.unsqueeze,
            aten.alias,
            aten.view,
            aten.slice,
            aten.t,
            prims.broadcast_in_dim,
            aten.expand,
            aten.as_strided,
            aten.permute,
        ]

        self.indexing_op_packets = {
            aten.slice,
        }

    def _support_dynamic_shape(self):
        # Unlike the base folder, uniform-value folding is safe with symints.
        return True

    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
        # Any tensor is acceptable; we only ever materialize a scalar from it.
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        # Record the uniform scalar value (not the tensor) plus the shape the
        # graph expects, so a `full` constructor can be emitted later.
        self.node_replacements[node] = tensor.flatten()[0].item()
        self.node_replacements_shapes[node] = node.meta["val"].shape
        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())

    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        # SymInt placeholders keep their symbolic value so shape arithmetic can
        # fold; tensor placeholders are unknown and block folding through them.
        for n in self.module.graph.find_nodes(op="placeholder"):
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                env[n] = n.meta["val"]
            else:
                env[n] = self.unknown_value

    def _deduce_value(self, node: torch.fx.Node):
        """Compute the substitute value for *node*, or ``self.unknown_value``."""
        # deduce value for full-like nodes
        # 1. for constructors, substitute value is a tensor of size [1]
        # 2. for view ops/indexing, substitute value is the same as the input
        # 3. for pointwise ops, run node to get the substitute value
        # 4. deal with some special ops
        # otherwise, stop deduce value and return unknown value

        # TODO: cat, more indexing
        # TODO - do on cpu to avoid syncs

        # single-elem attrs
        if node.op == "get_attr" or (
            node.op == "call_function"
            and node.target == torch.ops.aten.lift_fresh_copy.default
        ):
            out = super(ConstantFolder, self).run_node(node)
            if isinstance(out, torch.Tensor) and out.numel() == 1:
                return out

        # handle device_put op
        if node.target == prims.device_put.default:
            return super(ConstantFolder, self).run_node(node)

        # constructors ops: run `full` with a size-[1] shape so we never
        # materialize the real (possibly huge / symbolic) tensor
        if (
            node.op == "call_function"
            and node.target == aten.full.default
            and len(node.args) == 2
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            new_args = [[1], args[1]]
            return aten.full.default(*new_args, **node.kwargs)

        # handle before view ops because this changes value
        if node.target == aten.view.dtype:
            return super(ConstantFolder, self).run_node(node)

        # view ops, return input tensor, the first argument
        if hasattr(node.target, "overloadpacket") and (
            node.target.overloadpacket in self.view_op_packets
            or node.target.overloadpacket in self.indexing_op_packets
        ):
            assert isinstance(node.args[0], torch.fx.Node)
            return self.env[node.args[0]]

        # we don't want to return unknown value for symints so that we can
        # still constant fold through their use in constructors or views
        # if we see them in a pointwise node (e.g., tensor * symint)
        # we will bail
        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
            return node.meta["val"]

        # pointwise ops
        if isinstance(node.target, torch._ops.OpOverload) and (
            torch.Tag.pointwise in node.target.tags
            or node.target is torch.ops.aten.scalar_tensor.default
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
                return self.unknown_value

            # we run the ops with dim 1, so remove memory_format to avoid error
            kwargs = dict(kwargs)
            kwargs.pop("memory_format", None)

            return node.target(*args, **kwargs)

        return self.unknown_value
|
| 324 |
+
|
| 325 |
+
def constant_fold_uniform_value(gm: torch.fx.GraphModule):
    """Runs constant folding and replaces constants which can be constructed
    with a single `full` call. Calls into remove_no_ops.

    Fix: this text was previously a no-op string *statement* inside the
    ``with`` block instead of the function docstring; it is now a real
    docstring. No runtime behavior changes.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        aten = torch.ops.aten

        # Constant folding can leak memory, especially with repeated compilation, so we are only going to
        # remove constants which can be replaced with a constructor.
        cf = UniformValueConstantFolder(gm)
        cf.run()

        node_replacements = cf.node_replacements

        # note: [constant folding refining of symints]
        # constant folding will partially evaluate a graph such that values which have dependencies which
        # are entirely known at compile time may also become compile time constants. in some cases,
        # this will include symints which we had not yet previously deduced are guaranteed a
        # constant value and is then deduced in constant folding. an example is:
        # unbacked_symint_eq_11 = torch.full((), 11).item()
        # torch.full((unbacked_symint_eq_11,), 0)
        node_replacements_shapes = cf.node_replacements_shapes

        graph = gm.graph

        zeros = set()
        ones = set()

        # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
        # so just constant-ify if a Tensor is unaliased
        constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()

        for node in cf.node_replacements:
            constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1

        for node, value in node_replacements.items():
            # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now
            # hasn't shown up to be important yet
            if "val" not in node.meta:
                # This can only happen in AOTI
                continue

            fake_tensor = node.meta["val"]
            if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
                continue

            # TODO - not sure about lossy uint->python value->uint conversions
            if fake_tensor.dtype in (
                torch.uint8,
                torch.uint16,
                torch.uint32,
                torch.uint64,
            ):
                continue

            # Skip aliased constants: more than one folded node shares this storage.
            if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
                continue

            with graph.inserting_after(node):
                # the conversion from tensor and back to value can be lossy, just use the original full ctor value
                if (
                    node.op == "call_function"
                    and node.target == aten.full.default
                    and len(node.args) == 2
                ):
                    value = node.args[1]

                # refines symints, see [constant folding refining of symints] above
                for runtime_size, compile_time_size in zip(
                    node_replacements_shapes[node], fake_tensor.shape
                ):
                    torch._check(runtime_size == compile_time_size)

                # replace SymInt as Node before creating a new full node
                # e.g. (1, s0) -> (1, arg0_1)
                node_shape = node_replacements_shapes[node]
                if not all(
                    not isinstance(s, torch.SymInt) or s in cf.symint_nodes
                    for s in node_shape
                ):
                    # A symbolic dim has no producing node to reference; bail.
                    continue

                shapes = [
                    cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
                    for s in node_replacements_shapes[node]
                ]

                # zeros and ones just get traced into full, so we insert those
                new_node = graph.call_function(
                    aten.full.default,
                    args=(shapes, value),
                    kwargs={
                        "dtype": fake_tensor.dtype,
                        "layout": torch.strided,
                        "device": fake_tensor.device,
                        "pin_memory": False,
                    },
                )

                new_node.meta.update(node.meta)
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)

                # Remember all-zero / all-one constructors so remove_no_ops can
                # strip the arithmetic identities they feed.
                if value == 0:
                    zeros.add(new_node)
                elif value == 1:
                    ones.add(new_node)

        remove_no_ops(gm, zeros, ones)
        remove_redundant_views(gm)
| 434 |
+
|
| 435 |
+
def joint_graph_passes(graph: torch.fx.GraphModule):
    """
    Run FX transformations on the joint forwards+backwards graph.

    Applies, in order: the custom pre-pass, noop removal, uniform-value
    constant folding, registered pattern passes, random-op replacement, and
    the custom post-pass. The graph is re-sorted/linted/recompiled only if
    some pass reported a change via ``count``.
    """
    lazy_init()
    count = 0
    if config.joint_custom_pre_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_pre_pass(graph.graph)
            count += 1

    from .post_grad import remove_noop_ops

    remove_noop_ops(graph.graph)

    if config.joint_graph_constant_folding:
        with GraphTransformObserver(
            graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform
        ):
            constant_fold_uniform_value(graph)

    if config.pattern_matcher:
        for patterns in pass_patterns:
            count += patterns.apply(graph.graph)  # type: ignore[arg-type]

    if not config.fallback_random:
        count += replace_random_passes(graph)

    if config.joint_custom_post_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_post_pass(graph.graph)
            count += 1

    if count:
        # Some pass mutated the graph: restore a stable order and re-verify
        # before recompiling the generated forward().
        stable_topological_sort(graph.graph)
        graph.graph.lint()
        graph.recompile()
    return graph
| 478 |
+
|
| 479 |
+
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.iota.default,
        KeywordArg("length"),
        start=KeywordArg("start"),
        step=KeywordArg("step"),
        dtype=KeywordArg("dtype"),
        device=KeywordArg("device"),
        requires_grad=KeywordArg("requires_grad"),
    ),
    pass_dict=patterns,
)
def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad):
    """
    Eager supports:

        aten.index(cuda_tensor, torch.arange(..., device="cpu"))

    But this results in an implicit host-device-copy and breaks cudagraphs.
    Rewrite the arange to use CUDA.
    """
    (node,) = match.nodes
    user_devices: OrderedSet[torch.device] = OrderedSet()
    for user in node.users:
        if (
            user.op == "call_function"
            and user.target in (aten.index.Tensor, aten.index_put.default)
            and hasattr(user.meta.get("val"), "device")
        ):
            user_devices.add(user.meta["val"].device)  # type: ignore[union-attr]
        else:
            # Any non-indexing consumer means moving the iota could change
            # observable behavior, so bail out entirely.
            return  # bail out

    if len(user_devices) == 1 and "val" in node.meta:
        (user_device,) = user_devices
        if device.type != user_device.type:
            # Re-emit the iota on the consumers' device and fix up its
            # recorded example value to match.
            repl = match.graph.call_function(
                torch.ops.prims.iota.default,
                (length,),
                {
                    "start": start,
                    "step": step,
                    "dtype": dtype,
                    "device": user_device,
                    "requires_grad": requires_grad,
                },
            )
            repl.meta.update(node.meta)
            repl.meta["val"] = repl.meta["val"].to(user_device)
            node.replace_all_uses_with(repl)
            match.erase_nodes()
| 531 |
+
|
| 532 |
+
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        CallFunction(
            torch.ops.prims.convert_element_type.default,
            KeywordArg("arg"),
            KeywordArg("dtype1"),
        ),
        KeywordArg("dtype2"),
    ),
    pass_dict=patterns,
)
def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
    """Remove chain of dtype conversions often created by AMP"""
    graph = match.graph
    node = match.output_node()
    # Collapsing x -> dtype1 -> dtype2 into x -> dtype2 is only value-preserving
    # among float types (an intermediate int/bool cast would truncate).
    allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64}
    if dtype1 in allowed and dtype2 in allowed:
        repl = graph.call_function(
            torch.ops.prims.convert_element_type.default, (arg, dtype2)
        )
        repl.meta.update(node.meta)
        node.replace_all_uses_with(repl)
        match.erase_nodes()
| 557 |
+
|
| 558 |
+
@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Remove no-op view"""
    node = match.output_node()
    # If the requested size equals the input's recorded shape, the view is an
    # identity and its consumers can use the input directly.
    arg_size = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
    if size == arg_size:
        node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
        match.erase_nodes()
| 570 |
+
|
| 571 |
+
# When softmax is used with temperature or other scaling, we get the pattern
#
#   scale(x) - scale(x).amax(dim, keepdim=True)
#
# which is expected to be at most zero, but we may end up with numerical
# discrepancies between the recomputed values of scale(x) inside and out
# of the reduction, depending on compiler optimizations, e.g. use of fma
# instructions.
#
# Here we replace it with the mathematically equivalent,
#
#   scale(x - x.amax(dim, keepdim=True))
#
# which is more stable as we only compute the scaling once.
#
# NOTE: This pattern must come after fused attention matching!
| 588 |
+
|
| 589 |
+
def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False):
    """Build the pattern ``scale(x) - scale(x).amax(dim, keepdim)``.

    ``linear_func`` is the scaling op (e.g. ``aten.mul.Tensor`` or
    ``aten.div.Tensor``); ``reverse`` swaps the operand order of the scaling;
    ``to_dtype`` additionally matches an AMP-style dtype upcast between the
    scaling and the reduction.
    """
    # Allow matching inp * other and other * input
    if reverse:
        scaled = CallFunction(
            linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE
        )
    else:
        scaled = CallFunction(
            linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE
        )
    if to_dtype:
        scaled = CallFunction(
            prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE
        )
    amax = CallFunction(
        aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim")
    )
    return CallFunction(aten.sub.Tensor, scaled, amax)
| 608 |
+
|
| 609 |
+
def _other_is_broadcasted_in_dim(match):
    """Extra check: the scaling factor must be constant along the reduction dim.

    Returns True only when rescaling cannot change which index attains the
    maximum, which is the precondition for hoisting the scale out of the
    ``amax`` reduction.
    """
    # Check that the scaling factor is constant across the reduction dim,
    # so scaling doesn't change which index corresponds to the maximum value
    other = match.kwargs["other"]
    if isinstance(other, (int, float)):
        return True

    inp = match.kwargs["inp"]
    if not all(isinstance(x, torch.fx.Node) for x in (inp, other)):
        return False

    inp_example = inp.meta["val"]
    other_example = other.meta["val"]
    if isinstance(other_example, (torch.SymInt, torch.SymFloat)):
        # A scalar symbolic value is uniform by construction.
        return True

    if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)):
        return False

    inp_ndim = inp_example.ndim
    other_shape = other_example.shape
    if inp_ndim < len(other_shape):
        return False

    # Pad other_shape to the same ndim as inp
    other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape)

    dim = match.kwargs["dim"]
    if isinstance(dim, int):
        dim = (dim,)

    # Size 1 along every reduced dim means `other` is broadcast there.
    return all(statically_known_true(other_shape[d] == 1) for d in dim)
| 642 |
+
|
| 643 |
+
def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    """Rewrite ``inp*other - (inp*other).amax(...)`` as ``(inp - inp.amax(...)) * other``.

    The sign bookkeeping keeps the amax correct when ``other`` is negative:
    we reduce over ``inp * sign(other)`` so the max is taken in the right
    direction, then multiply the stabilized difference by ``sign * other``.
    """
    def repl(inp, other):
        if dtype is not None:
            # Mirror the matched AMP upcast before doing any arithmetic.
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, one, -one)

        inp = inp * sign
        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
        return (inp - max_) * (sign * other)

    match.replace_by_example(repl, [inp, other])
| 661 |
+
|
| 662 |
+
# Register the mul-based partial-softmax rewrite for every combination of
# operand order (inp*other vs other*inp) and optional AMP dtype upcast.
for reverse, to_dtype in itertools.product((False, True), repeat=2):
    register_graph_pattern(
        _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(mul_softmax_pattern)
| 669 |
+
|
| 670 |
+
def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    """Rewrite ``inp/other - (inp/other).amax(...)`` as ``(inp - inp.amax(...)) / other``.

    Same sign handling as ``mul_softmax_pattern``: reducing over
    ``inp * sign(other)`` keeps the amax direction correct for negative
    divisors, and the stabilized difference is divided by ``sign * other``.
    """
    def repl(inp, other):
        if dtype is not None:
            # Mirror the matched AMP upcast before doing any arithmetic.
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, one, -one)

        inp = inp * sign
        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
        return (inp - max_) / (sign * other)

    match.replace_by_example(repl, [inp, other])
| 688 |
+
|
| 689 |
+
# Register the div-based partial-softmax rewrite, with and without the AMP
# dtype upcast (division is not commutative, so no `reverse` variant).
for to_dtype in (False, True):
    register_graph_pattern(
        _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(div_softmax_pattern)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py
ADDED
|
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import operator
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any, cast, Dict, List, Optional, Set
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .. import config, inductor_prims
|
| 10 |
+
from ..pattern_matcher import (
|
| 11 |
+
CallFunction,
|
| 12 |
+
Ignored,
|
| 13 |
+
KeywordArg,
|
| 14 |
+
ListOf,
|
| 15 |
+
Match,
|
| 16 |
+
MULTIPLE,
|
| 17 |
+
PatternExpr,
|
| 18 |
+
PatternMatcherPass,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
aten = torch.ops.aten
|
| 23 |
+
patterns = PatternMatcherPass()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _is_backward(graph: torch.fx.Graph) -> bool:
|
| 27 |
+
placeholders = []
|
| 28 |
+
for node in graph.nodes:
|
| 29 |
+
if node.op != "placeholder":
|
| 30 |
+
break
|
| 31 |
+
placeholders.append(node)
|
| 32 |
+
return not all(node.name.startswith("primal") for node in placeholders)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float:
|
| 36 |
+
return M * N * K / (M * K + N * K + M * N)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]:
|
| 40 |
+
return [x for x in nodes if x.target == target]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]:
|
| 44 |
+
ancestors = set()
|
| 45 |
+
ancestors.add(node)
|
| 46 |
+
cur_nodes = [node]
|
| 47 |
+
while len(cur_nodes) > 0:
|
| 48 |
+
new_nodes = []
|
| 49 |
+
for node in cur_nodes:
|
| 50 |
+
for inp in node.all_input_nodes:
|
| 51 |
+
if inp not in ancestors:
|
| 52 |
+
ancestors.add(inp)
|
| 53 |
+
new_nodes.append(inp)
|
| 54 |
+
cur_nodes = new_nodes
|
| 55 |
+
return {node for node in ancestors if node.op != "placeholder"}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _get_tensor(node: torch.fx.Node) -> torch.Tensor:
|
| 59 |
+
val = node.meta["val"]
|
| 60 |
+
assert isinstance(val, torch.Tensor)
|
| 61 |
+
return val
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
class _AllGatherMatch:
    """A matched all-gather subgraph plus the parameters needed to fuse it."""

    match: Match
    # Local shard fed into the all-gather.
    shard_node: torch.fx.Node
    # The all_gather_into_tensor collective node itself.
    ag_node: torch.fx.Node
    # Node producing the final gathered result consumed downstream.
    res_node: torch.fx.Node
    # Dimension along which shards are concatenated.
    gather_dim: int
    # Process-group name used by the collective.
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        # Redirect every consumer of the gathered result to new_node.
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        # Erase matched nodes in reverse order so consumers go before
        # producers; nodes that still have users are left in place.
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)
| 81 |
+
|
| 82 |
+
def find_all_gather_patterns(graph: torch.fx.Graph):
    """Find all-gather patterns in ``graph`` and return them as
    ``_AllGatherMatch`` objects, in graph order.

    Recognizes four shapes produced by funcol.all_gather_tensor:
    gather_dim == 0, gather_dim > 0 (split + cat), and the uint8
    "type-erased" variants of both (wrapped in aten.view.dtype).
    """
    c10d = torch.ops._c10d_functional

    def make_zero_dim_all_gather_pattern(shard):
        # wait_tensor(all_gather_into_tensor(shard, _, group_name))
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.all_gather_into_tensor.default,
                shard,
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.all_gather_tensor with gather_dim == 0
    zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard"))

    def make_all_gather_split_pattern(shard):
        # getitem(split(all_gather(...), _), _) -- one chunk of the gathered
        # tensor, used to move the gather dim via cat.
        return CallFunction(
            operator.getitem,
            CallFunction(
                aten.split.Tensor,
                make_zero_dim_all_gather_pattern(shard),
                Ignored(),
                _users=MULTIPLE,
            ),
            Ignored(),
        )

    def make_cat_pattern(splits):
        return CallFunction(
            aten.cat.default,
            ListOf(splits),
            KeywordArg("gather_dim"),
        )

    # Matches funcol.all_gather_tensor with gather_dim > 0
    non_zero_dim_all_gather_pattern = make_cat_pattern(
        make_all_gather_split_pattern(KeywordArg("shard")),
    )

    # Match a zero-dim all-gather in which the data is transferred as uint8 and
    # viewed back as the original dtype.
    zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_zero_dim_all_gather_pattern(
            KeywordArg("shard"),
        ),
        Ignored(),
    )

    # Match a non-zero dim all-gather in which the data is transferred as uint8
    # and viewed back as the original dtype.
    non_zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_cat_pattern(
            CallFunction(
                aten.view.dtype,
                make_all_gather_split_pattern(
                    KeywordArg("shard"),
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    )

    # If two patterns with the same res_node_target have the same suffix, the
    # longer pattern should appear first in the list.
    # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1)
    # should appear before (2) in the list.
    res_node_target_to_patterns = {
        aten.cat.default: [
            (non_zero_dim_all_gather_pattern, 0),
        ],
        aten.view.dtype: [
            (non_zero_dim_type_erased_all_gather_pattern, 0),
            (zero_dim_type_erased_all_gather_pattern, 0),
        ],
        c10d.wait_tensor.default: [
            (zero_dim_all_gather_pattern, 0),
        ],
    }

    # Match in reverse to ensure longer patterns is prioritized
    all_gathers = []
    visited_ag_nodes = set()
    for node in reversed(graph.nodes):
        for target, patterns in res_node_target_to_patterns.items():
            if node.target != target:
                continue
            for pattern, ag_node_idx in patterns:
                match = pattern.match(node)
                if not match:
                    continue

                assert isinstance(match, Match)
                ag_node = match.nodes[ag_node_idx]
                assert ag_node.target == c10d.all_gather_into_tensor.default

                # A longer pattern matched earlier may already own this
                # collective; don't report it twice.
                if ag_node in visited_ag_nodes:
                    continue
                visited_ag_nodes.add(ag_node)

                ag_match = _AllGatherMatch(
                    match=match,
                    shard_node=match.kwargs["shard"],
                    ag_node=ag_node,
                    res_node=node,
                    # Zero-dim patterns don't bind gather_dim; default to 0.
                    gather_dim=match.kwargs.get("gather_dim", 0),
                    group_name=match.kwargs["group_name"],
                )
                all_gathers.append(ag_match)

    return list(reversed(all_gathers))
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@dataclass
class _ReduceScatterMatch:
    """A matched reduce-scatter pattern in the graph.

    Ties together the input node, the reduce_scatter_tensor collective
    node, and the node producing the final result (the wait_tensor).
    """

    match: Match
    # Node feeding the reduce-scatter (the full, pre-scatter tensor).
    input_node: torch.fx.Node
    # The reduce_scatter_tensor collective node.
    rs_node: torch.fx.Node
    # Node producing the final reduce-scatter result consumed downstream.
    res_node: torch.fx.Node
    reduce_op: str
    scatter_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        """Redirect all users of the reduce-scatter result to ``new_node``."""
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        """Erase matched nodes that became dead, visiting uses-last first."""
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def find_reduce_scatter_patterns(graph: torch.fx.Graph):
    """Find reduce-scatter patterns in ``graph`` and return them as
    ``_ReduceScatterMatch`` objects, in graph order.

    Recognizes funcol.reduce_scatter_tensor with scatter_dim == 0 and
    the split+cat form used for scatter_dim > 0.
    """
    c10d = torch.ops._c10d_functional

    def reduce_scatter_template(inp: PatternExpr):
        # wait_tensor(reduce_scatter_tensor(inp, reduce_op, _, group_name))
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.reduce_scatter_tensor.default,
                inp,
                KeywordArg("reduce_op"),
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
    zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))

    # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
    non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
        CallFunction(
            aten.cat.default,
            ListOf(
                CallFunction(
                    operator.getitem,
                    CallFunction(
                        aten.split.Tensor,
                        KeywordArg("input"),
                        Ignored(),
                        KeywordArg("scatter_dim"),
                        _users=MULTIPLE,
                    ),
                    Ignored(),
                )
            ),
        ),
    )

    reduce_scatters = []
    for node in reversed(graph.nodes):
        if node.target == c10d.wait_tensor.default:
            # Try the longer (non-zero-dim) pattern first.
            if match := non_zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        # Second-to-last matched node is the collective.
                        rs_node=match.nodes[-2],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=match.kwargs["scatter_dim"],
                        group_name=match.kwargs["group_name"],
                    )
                )
            elif match := zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        # First matched node is the collective in this pattern.
                        rs_node=match.nodes[0],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=0,
                        group_name=match.kwargs["group_name"],
                    )
                )
    return list(reversed(reduce_scatters))
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@dataclass
class _Matmul:
    """A matched ``aten.mm``/``aten._scaled_mm``, possibly reshape-wrapped.

    ``nodes`` is either ``[mm]`` for a 2D matmul, or
    ``[reshape, mm, reshape]`` for an ND matmul that was flattened to 2D.
    """

    nodes: List[torch.fx.Node]
    # Ancestors of the non-A arguments; populated in __post_init__.
    arg_ancestor_nodes: Set[torch.fx.Node] = field(init=False)
    A_node: torch.fx.Node
    B_node: torch.fx.Node

    def __post_init__(self):
        # Validate the 1-node (2D) or 3-node (reshape -> mm -> reshape) shape.
        assert len(self.nodes) in (1, 3)
        if len(self.nodes) == 1:
            assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
        else:
            assert self.nodes[0].target == aten.reshape.default
            assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
            assert self.nodes[2].target == aten.reshape.default
        self.arg_ancestor_nodes = _find_ancestors(self.B_node)

    def replace_with(self, new_node: torch.fx.Node) -> None:
        """
        Replace the matmul with the new node.
        """
        graph = new_node.graph

        # For 2D-matmuls, we simply replace the mm node with `new_node`.
        if len(self.nodes) == 1:
            mm_node = self.nodes[0]
            assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
            mm_node.replace_all_uses_with(new_node)
            graph.erase_node(mm_node)
            return

        # An ND-matmul is reshape -> mm -> reshape sequence. We first replace
        # the second reshape node with `new_node`. Then, we ensure that the
        # original mm node in the sequence ends up with zero users by replacing
        # it with a reverse reshape of `new_node`.
        graph = new_node.graph
        assert len(self.nodes) == 3
        mm_node = self.nodes[1]
        output_reshape_node = self.nodes[2]

        assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
        assert output_reshape_node.target == aten.reshape.default

        output_reshape_node.replace_all_uses_with(new_node)
        if len(mm_node.users) > 1:
            # mm_node has users besides the output reshape; give them an
            # equivalent value reshaped back to the flat 2D mm shape.
            with graph.inserting_after(new_node):
                new_mm_node = graph.call_function(
                    aten.reshape.default,
                    args=(new_node, list(_get_tensor(mm_node).shape)),
                )
            mm_node.replace_all_uses_with(new_mm_node)

    def erase(self) -> None:
        """Erase matched nodes that became dead, visiting uses-last first."""
        for node in reversed(self.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_Matmul":
        """Build a ``_Matmul`` from a 1-node (2D) or 3-node (ND) match."""
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten.mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]
        return _Matmul(
            nodes=match,
            # In the ND case match[0] is the input reshape, so its args[0]
            # is the original (un-flattened) A operand.
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
        )
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@dataclass
class _ScaledMatmul(_Matmul):
    """An ``aten._scaled_mm`` match, carrying its extra scale/bias arguments."""

    A_scale_node: torch.fx.Node
    B_scale_node: torch.fx.Node
    bias_node: Optional[torch.fx.Node]
    result_scale_node: Optional[torch.fx.Node]
    out_dtype: Optional[torch.dtype]
    use_fast_accum: bool

    def __post_init__(self):
        super().__post_init__()
        # Scale args also must not depend on the collective result, so fold
        # their ancestors into the fusibility check.
        self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node)
        self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
        """Build a ``_ScaledMatmul`` from a 1-node (2D) or 3-node (ND) match."""
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten._scaled_mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]

        def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
            # Positional args beyond those supplied fall back to `default`.
            if idx >= len(node.args):
                return default
            return node.args[idx]

        return _ScaledMatmul(
            nodes=match,
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
            A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
            B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
            bias_node=get_arg(mm_node, 4, None),
            result_scale_node=get_arg(mm_node, 5, None),
            out_dtype=get_arg(mm_node, 6, None),
            use_fast_accum=get_arg(mm_node, 7, False),
        )
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _find_reshape_mm_reshape(node: torch.fx.Node) -> List[_Matmul]:
    """Find reshape -> mm -> reshape chains rooted at ``node``.

    ``node`` is expected to be the input reshape; returns one ``_Matmul``
    (or ``_ScaledMatmul``) per qualifying chain.
    """
    if node.target != aten.reshape.default:
        return []

    matches = []
    for mm_node in node.users:
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            continue
        for reshape_node in mm_node.users:
            if reshape_node.target != aten.reshape.default:
                continue

            # Since the reshape -> mm -> reshape pattern would be subsumed into
            # the fused op, we only match the patterns where the shape of the
            # second reshape matches the mm result produced by the fused op.
            matmul_input_node = cast(torch.fx.Node, node.args[0])
            B_node = cast(torch.fx.Node, mm_node.args[1])
            matmul_out_shape = torch.Size(
                [
                    *_get_tensor(matmul_input_node).shape[:-1],
                    _get_tensor(B_node).shape[-1],
                ]
            )
            if _get_tensor(reshape_node).shape != matmul_out_shape:
                continue
            matches.append([node, mm_node, reshape_node])
            # If for some rare reason mm_node is being reshaped by two
            # different reshape nodes, we only include mm_node once in the
            # parsing result.
            break

    matmuls = []
    for match in matches:
        mm_node = match[1]
        if mm_node.target == aten.mm.default:
            matmul = _Matmul.from_match(match)
            matmuls.append(matmul)
        elif mm_node.target == aten._scaled_mm.default:
            matmul = _ScaledMatmul.from_match(match)
            matmuls.append(matmul)
        else:
            raise AssertionError(
                "Expect the node's target to be either aten.mm.default or "
                f"aten._scaled_mm.default. Got {mm_node.target}."
            )
    return matmuls
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _find_consumer_matmuls(node: torch.fx.Node) -> List[_Matmul]:
    """Collect the matmuls that consume ``node`` as their lhs argument."""
    found: List[_Matmul] = []
    for consumer in node.users:
        target = consumer.target
        if target == aten.reshape.default:
            # ND matmuls appear as reshape -> mm -> reshape chains.
            found.extend(_find_reshape_mm_reshape(consumer))
        elif target == aten.mm.default:
            # Plain 2D matmul.
            found.append(_Matmul.from_match(match=[consumer]))
        elif target == aten._scaled_mm.default:
            # Scaled (fp8) 2D matmul.
            found.append(_ScaledMatmul.from_match([consumer]))
    return found
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def _insert_fused_all_gather_matmul(
    graph: torch.fx.Graph,
    matmuls: List[_Matmul],
    shard_node: torch.fx.Node,
    gather_dim: int,
    group_name: str,
) -> torch.fx.Node:
    """Insert the symm_mem fused all-gather+matmul op for ``matmuls``.

    All matmuls must share one concrete type (_Matmul or _ScaledMatmul);
    that type selects which fused op is emitted. Returns the new node.
    """
    mm_types = set(map(type, matmuls))
    assert len(mm_types) == 1
    mm_type = next(iter(mm_types))
    if mm_type == _Matmul:
        B_nodes = [matmul.B_node for matmul in matmuls]
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_matmul.default,
            args=(shard_node, B_nodes, gather_dim, group_name),
        )
    elif mm_type == _ScaledMatmul:
        scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
            args=(
                shard_node,
                [matmul.B_node for matmul in scaled_matmuls],
                # All matmuls share the gathered A operand, so a single
                # A scale is passed (taken from the first match).
                scaled_matmuls[0].A_scale_node,
                [matmul.B_scale_node for matmul in scaled_matmuls],
                gather_dim,
                group_name,
                [matmul.bias_node for matmul in scaled_matmuls],
                [matmul.result_scale_node for matmul in scaled_matmuls],
                [matmul.out_dtype for matmul in scaled_matmuls],
                [matmul.use_fast_accum for matmul in scaled_matmuls],
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {mm_type}")
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
    """
    Fuse the pattern

        A = all_gather_tensor(A_shard, gather_dim, group_name)
        C_0 = torch.matmul(A, B_0)
        C_1 = torch.matmul(A, B_1)
        C_2 = torch.matmul(A, B_2)
        ...

    into

        A, Cs = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name,
        )

    No-op when torch.distributed/NCCL is unavailable, symmetric memory is
    not enabled for the group, the gather dim is the K dim, or no eligible
    consumer matmuls are found.
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_shard_for_fused_all_gather_matmul,
    )

    shard_node, ag_node, ag_res_node, gather_dim, group_name = (
        all_gather.shard_node,
        all_gather.ag_node,
        all_gather.res_node,
        all_gather.gather_dim,
        all_gather.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
        # Decomposing the matmul on the K dimension is not supported
        return

    # Find consumer matmuls
    matmuls = _find_consumer_matmuls(ag_res_node)

    # The matmuls are only fusible if non-A args don't depend on the all-gather
    # result node
    matmuls = [
        matmul
        for matmul in matmuls
        if all_gather.res_node not in matmul.arg_ancestor_nodes
    ]

    # _insert_fused_all_gather_matmul requires a single match type.
    if len(matmuls) == 0 or len(set(map(type, matmuls))) != 1:
        return

    # Fuse the all_gather_tensor with the eligible matmuls
    graph = ag_node.graph
    with graph.inserting_before(ag_node):
        if "val" in shard_node.meta:
            # Restride the shard to the layout the fused op prefers.
            restrided = restride_A_shard_for_fused_all_gather_matmul(
                _get_tensor(shard_node),
                gather_dim,
            )
            shard_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(shard_node, restrided.stride()),
            )

        fused_node = _insert_fused_all_gather_matmul(
            graph, matmuls, shard_node, gather_dim, group_name
        )
        # Fused op returns (gathered A, [matmul outputs]).
        new_ag_node = graph.call_function(
            operator.getitem,
            args=(fused_node, 0),
        )
        new_out_nodes = graph.call_function(
            operator.getitem,
            args=(fused_node, 1),
        )
        for idx, matmul in enumerate(matmuls):
            new_out_node = graph.call_function(
                operator.getitem,
                args=(new_out_nodes, idx),
            )
            matmul.replace_with(new_out_node)
            matmul.erase()
        all_gather.replace_with(new_ag_node)
        all_gather.erase()

    # Raise ancestors of non-A args that are topologically ordered between
    # ag_res_node and the matmul above fused_node.
    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        {x for matmul in matmuls for x in matmul.arg_ancestor_nodes},
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
    """Return the matmul that produces ``node``, or None.

    Handles a bare mm/_scaled_mm node, and the reshape -> mm -> reshape
    chain of an ND matmul flattened to 2D.
    """
    if node.target == aten.mm.default:
        return _Matmul.from_match(match=[node])
    elif node.target == aten._scaled_mm.default:
        return _ScaledMatmul.from_match(match=[node])
    elif node.target == aten.reshape.default:
        # Walk backwards: output reshape -> mm -> input reshape.
        reshape_node_1 = node

        mm_node = reshape_node_1.args[0]
        assert isinstance(mm_node, torch.fx.Node)
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            return None

        reshape_node_0 = mm_node.args[0]
        assert isinstance(reshape_node_0, torch.fx.Node)
        if reshape_node_0.target != aten.reshape.default:
            return None

        if mm_node.target == aten.mm.default:
            return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1])
        elif mm_node.target == aten._scaled_mm.default:
            return _ScaledMatmul.from_match(
                match=[reshape_node_0, mm_node, reshape_node_1]
            )
    return None
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def _insert_fused_matmul_reduce_scatter(
    graph: torch.fx.Graph,
    matmul: _Matmul,
    reduce_op: str,
    scatter_dim: int,
    group_name: str,
) -> torch.fx.Node:
    """Insert the symm_mem fused matmul+reduce-scatter op for ``matmul``.

    Returns the newly created node.
    """
    # NOTE: exact type comparisons (not isinstance) are intentional --
    # _ScaledMatmul subclasses _Matmul and must dispatch to the scaled op.
    if type(matmul) == _Matmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
            args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name),
        )
    elif type(matmul) == _ScaledMatmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
            args=(
                matmul.A_node,
                matmul.B_node,
                matmul.A_scale_node,
                matmul.B_scale_node,
                reduce_op,
                scatter_dim,
                group_name,
                matmul.bias_node,
                matmul.result_scale_node,
                matmul.out_dtype,
                matmul.use_fast_accum,
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
    """
    Fuse the pattern

        reduce_scatter_tensor(A @ B, scatter_dim, group_name)

    into

        torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, scatter_dim, group_name,
        )

    No-op when torch.distributed/NCCL is unavailable, symmetric memory is
    not enabled for the group, the input isn't produced by a lone matmul,
    or the matmul's non-A args depend on the reduce-scatter result.
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_for_fused_matmul_reduce_scatter,
    )

    input_node, rs_node, rs_res_node, reduce_op, scatter_dim, group_name = (
        reduce_scatter.input_node,
        reduce_scatter.rs_node,
        reduce_scatter.res_node,
        reduce_scatter.reduce_op,
        reduce_scatter.scatter_dim,
        reduce_scatter.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
    # so we can't apply the fusion if the matmul result is used by multiple
    # users. This is not a fundamental limitation of the fused op and can be
    # addressed if needed.
    if len(input_node.users) != 1:
        return

    matmul = _find_producer_matmul(input_node)
    if matmul is None:
        return

    if rs_res_node in matmul.arg_ancestor_nodes:
        return

    graph = rs_res_node.graph
    with graph.inserting_before(rs_res_node):
        if "val" in matmul.A_node.meta:
            # Restride A to the layout the fused op prefers.
            restrided = restride_A_for_fused_matmul_reduce_scatter(
                _get_tensor(matmul.A_node),
                scatter_dim,
            )
            matmul.A_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(matmul.A_node, restrided.stride()),
            )

        fused_node = _insert_fused_matmul_reduce_scatter(
            graph,
            matmul,
            reduce_op,
            scatter_dim,
            group_name,
        )
        reduce_scatter.replace_with(fused_node)
        reduce_scatter.erase()
        matmul.erase()

    # Raise ancestors of the matmul's non-A args above the fused node so all
    # of its operands are defined before it executes.
    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        matmul.arg_ancestor_nodes,
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def _get_node_to_ancestors(
|
| 752 |
+
graph: torch.fx.Graph,
|
| 753 |
+
) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
|
| 754 |
+
"""
|
| 755 |
+
Compute the ancestors for all nodes in a graph.
|
| 756 |
+
"""
|
| 757 |
+
node_to_ancestors = defaultdict(set)
|
| 758 |
+
for node in graph.nodes:
|
| 759 |
+
node_to_ancestors[node] = set(node.all_input_nodes)
|
| 760 |
+
for dep in node.all_input_nodes:
|
| 761 |
+
node_to_ancestors[node] |= node_to_ancestors[dep]
|
| 762 |
+
|
| 763 |
+
return node_to_ancestors
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
def _get_collective_to_overlappable_nodes(
    graph: torch.fx.Graph,
) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
    """
    For each collective in the graph, find nodes that are neither ancestors nor
    descendants of the collective.
    """

    def is_collective(node) -> bool:
        # Only consider all-gather and reduce-scatter in the context of
        # micro-pipeline TP.
        return node.target in [
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            torch.ops._c10d_functional.reduce_scatter_tensor.default,
        ]

    node_to_ancestors = _get_node_to_ancestors(graph)
    collective_to_overlappable_nodes = defaultdict(list)
    for node in graph.nodes:
        if not is_collective(node):
            continue
        for x in graph.nodes:
            # x can overlap with `node` iff neither depends on the other.
            if (
                node not in node_to_ancestors[x]
                and x not in node_to_ancestors[node]
                and x.op == "call_function"
            ):
                collective_to_overlappable_nodes[node].append(x)

    return collective_to_overlappable_nodes
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def _get_unexposed_collectives(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Find all unexposed collectives in the graph.

    Because we don't have the runtime estimate, this function is a rough
    estimation using the following strong/hand-wavy assumptions:

    - Only a predefined set of "compute intensive" operation can hide a collective.
    - Any "compute intensive" operation can hide exactly one collective.
    """

    def _is_compute_intensive(node: torch.fx.Node) -> bool:
        # Only plain matmuls count as able to hide a collective.
        return node.target in [torch.ops.aten.mm.default]

    collective_to_overlapping_candidates = defaultdict(list)
    available_nodes = set()
    collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph)
    for collective, overlappable_nodes in collective_to_overlappable_nodes.items():
        candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)]
        collective_to_overlapping_candidates[collective] = candidates
        available_nodes |= set(candidates)

    unexposed_collectives = []
    for (
        collective,
        overlapping_candidates,
    ) in collective_to_overlapping_candidates.items():
        # Each collective consumes exactly one overlapping candidate
        for x in overlapping_candidates:
            if x in available_nodes:
                unexposed_collectives.append(collective)
                available_nodes.remove(x)
                break
    return unexposed_collectives
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def micro_pipeline_tp_pass(graph: torch.fx.Graph):
    """Apply micro-pipeline TP fusions to ``graph``.

    Fuses all-gather + matmul and matmul + reduce-scatter patterns into
    their symm_mem fused ops, skipping collectives that can instead be
    hidden by simple compute/comm overlapping.
    """
    all_gathers = find_all_gather_patterns(graph)
    reduce_scatters = find_reduce_scatter_patterns(graph)

    # When a collective can be hidden through either simple overlapping or
    # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
    # associated with decomposition. If reorder_for_compute_comm_overlap is
    # enabled, we identify collectives that can be hidden through simple
    # overlapping and exclude them from micro-pipeline TP candidates.
    if config.reorder_for_compute_comm_overlap:
        unexposed_collectives = _get_unexposed_collectives(graph)
        all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
        reduce_scatters = [
            x for x in reduce_scatters if x.rs_node not in unexposed_collectives
        ]

    for all_gather in all_gathers:
        fuse_all_gather_matmul(all_gather)

    for reduce_scatter in reduce_scatters:
        fuse_matmul_reduce_scatter(reduce_scatter)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
from typing import Dict, Set, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._dynamo.utils import counters
|
| 7 |
+
from torch._ops import OpOverload, OpOverloadPacket
|
| 8 |
+
|
| 9 |
+
from ..pattern_matcher import fwd_only, register_replacement
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@functools.lru_cache(None)
|
| 16 |
+
def _misc_patterns_init():
|
| 17 |
+
from .joint_graph import patterns as joint_graph_patterns
|
| 18 |
+
from .post_grad import pass_patterns as post_grad_patterns_all
|
| 19 |
+
|
| 20 |
+
post_grad_patterns = post_grad_patterns_all[1] # medium priority
|
| 21 |
+
|
| 22 |
+
if torch.cuda.is_available():
|
| 23 |
+
# workaround https://github.com/pytorch/pytorch/issues/97894
|
| 24 |
+
device = "cuda"
|
| 25 |
+
else:
|
| 26 |
+
device = "cpu"
|
| 27 |
+
|
| 28 |
+
# These patterns do 2 things
|
| 29 |
+
# 1. Since we know that index is completely unique, we can codegen it using
|
| 30 |
+
# stores instead of atomic adds, which is quite a bit faster.
|
| 31 |
+
# 2. Also, since we are guaranteed that they are completely within bounds,
|
| 32 |
+
# we can use unsafe indexing and skip debug asserts
|
| 33 |
+
def randperm_index_add_pattern(x, y):
|
| 34 |
+
index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
|
| 35 |
+
return torch.index_add(x, dim=0, source=y, index=index), index
|
| 36 |
+
|
| 37 |
+
def randperm_index_add_replacement(x, y):
|
| 38 |
+
index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
|
| 39 |
+
return (
|
| 40 |
+
torch.ops.aten._unsafe_index_put(
|
| 41 |
+
x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
|
| 42 |
+
),
|
| 43 |
+
index,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
register_replacement(
|
| 47 |
+
randperm_index_add_pattern,
|
| 48 |
+
randperm_index_add_replacement,
|
| 49 |
+
[torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
|
| 50 |
+
fwd_only,
|
| 51 |
+
[post_grad_patterns, joint_graph_patterns],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def randperm_index_pattern(x, slice_shape):
|
| 55 |
+
index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
|
| 56 |
+
return torch.ops.aten.index(x, (index,)), index
|
| 57 |
+
|
| 58 |
+
def randperm_index_replacement(x, slice_shape):
|
| 59 |
+
index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
|
| 60 |
+
return torch.ops.aten._unsafe_index(x, (index,)), index
|
| 61 |
+
|
| 62 |
+
register_replacement(
|
| 63 |
+
randperm_index_pattern,
|
| 64 |
+
randperm_index_replacement,
|
| 65 |
+
[torch.empty(4, 8, device=device)],
|
| 66 |
+
fwd_only,
|
| 67 |
+
[post_grad_patterns, joint_graph_patterns],
|
| 68 |
+
scalar_workaround={"slice_shape": 42},
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class NumpyCompatNormalization:
    """Graph pass that renames numpy-style kwargs on ``call_function`` nodes
    to their torch spellings (e.g. ``axis`` -> ``dim``, ``keepdims`` ->
    ``keepdim``), so numpy-flavored user code lowers like native torch calls.
    """

    # canonical torch kwarg -> accepted numpy-style spellings
    numpy_compat: Dict[str, Tuple[str, ...]] = {
        "dim": ("axis",),
        "keepdim": ("keepdims",),
        "input": ("x", "a", "x1"),
        "other": ("x2",),
    }
    # numpy spelling -> canonical torch kwarg, derived in __init__
    inverse_mapping: Dict[str, str]
    # per-target memo of which numpy spellings that op can accept
    cache: Dict["torch.fx.graph.Target", Set[str]]

    def __init__(self) -> None:
        self.cache = {}  # callable -> tuple of replaceable args e.g. ["axis"]
        self.inverse_mapping = {}
        for canonical, spellings in self.numpy_compat.items():
            for spelling in spellings:
                # each numpy spelling must map to exactly one torch kwarg
                assert spelling not in self.inverse_mapping
                self.inverse_mapping[spelling] = canonical

    def __call__(self, graph: torch.fx.Graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
                continue

            renameable = self.cache.get(node.target)
            if renameable is None:
                # Inspect the op's schema(s) once and memoize which numpy
                # spellings could be renamed for this target.
                sigs = torch.fx.operator_schemas.get_signature_for_torch_op(
                    node.target
                )
                renameable = set()
                for sig in sigs or ():
                    for param_name in sig.parameters.keys():
                        if param_name in self.numpy_compat:
                            renameable.update(self.numpy_compat[param_name])
                self.cache[node.target] = renameable

            if not renameable:
                continue

            renamed = {
                (self.inverse_mapping[key] if key in renameable else key): value
                for key, value in node.kwargs.items()
            }
            # Keys differ exactly when at least one kwarg was renamed.
            if renamed.keys() != node.kwargs.keys():
                node.kwargs = torch.fx.immutable_collections.immutable_dict(renamed)
                counters["inductor"]["numpy_compat_normalization"] += 1
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Module-level singleton so all graph passes share one schema cache.
numpy_compat_normalization = NumpyCompatNormalization()
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py
ADDED
|
@@ -0,0 +1,1266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import operator
|
| 4 |
+
from functools import reduce
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
| 9 |
+
|
| 10 |
+
from .. import ir
|
| 11 |
+
from ..lowering import lowerings as L
|
| 12 |
+
from ..pattern_matcher import (
|
| 13 |
+
Arg,
|
| 14 |
+
CallFunction,
|
| 15 |
+
filter_nodes,
|
| 16 |
+
get_arg_value,
|
| 17 |
+
KeywordArg,
|
| 18 |
+
MULTIPLE,
|
| 19 |
+
)
|
| 20 |
+
from ..virtualized import ops, V
|
| 21 |
+
from .freezing_patterns import register_freezing_graph_pattern
|
| 22 |
+
from .post_grad import register_lowering_pattern
|
| 23 |
+
from .quantization import (
|
| 24 |
+
_register_quantization_lowerings,
|
| 25 |
+
_register_quantization_weight_pack_pass,
|
| 26 |
+
_register_woq_lowerings,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if torch._C._has_mkldnn:
|
| 31 |
+
# Short aliases for the op namespaces used throughout this module.
aten = torch.ops.aten
mkldnn = torch.ops.mkldnn
prims = torch.ops.prims

# Positional wildcard args matching each mkldnn computation op's arity.
_conv_args = [Arg() for _ in range(10)]
_linear_args = [Arg() for _ in range(6)]
_conv_transpose_args = [Arg() for _ in range(11)]
|
| 38 |
+
|
| 39 |
+
def _conv_call(users=1):
    """Pattern node matching mkldnn convolution-pointwise with `users` consumers."""
    return CallFunction(
        mkldnn._convolution_pointwise.default, *_conv_args, _users=users
    )
|
| 43 |
+
|
| 44 |
+
def _linear_call(users=1):
    """Pattern node matching mkldnn linear-pointwise with `users` consumers."""
    return CallFunction(
        mkldnn._linear_pointwise.default, *_linear_args, _users=users
    )
|
| 48 |
+
|
| 49 |
+
def _conv_transpose_call(users=1):
    """Pattern node matching mkldnn transposed-convolution-pointwise."""
    return CallFunction(
        mkldnn._convolution_transpose_pointwise.default,
        *_conv_transpose_args,
        _users=users,
    )
|
| 55 |
+
|
| 56 |
+
def _to_float(input_call, users=1):
    """Pattern: convert_element_type(input_call, dtype) with dtype bound as "to_float"."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_float"),
        _users=users,
    )
|
| 63 |
+
|
| 64 |
+
def _to_bf16(input_call):
    """Pattern: single-user convert_element_type with dtype bound as "to_bf16"."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_bf16"),
        _users=1,
    )
|
| 71 |
+
|
| 72 |
+
def _to_fp16(input_call):
    """Pattern: single-user convert_element_type with dtype bound as "to_fp16"."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_fp16"),
        _users=1,
    )
|
| 79 |
+
|
| 80 |
+
def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
    """Wrap a computation-op pattern in a unary fusion.

    For a low-precision dtype the matched subgraph is
    computation -> to_float -> unary -> to_{bf16,fp16}; otherwise the unary
    is applied to the computation pattern directly.
    """
    # only insert to_dtype if lowp_dtype is True
    if lowp_dtype:
        inner = _to_float(call_fn(), users=users)
    else:
        inner = call_fn(users=users)
    fused = unary_fusion(inner)
    recast = {torch.bfloat16: _to_bf16, torch.float16: _to_fp16}.get(lowp_dtype)
    return fused if recast is None else recast(fused)
|
| 92 |
+
|
| 93 |
+
def _gelu_fusion_1(computation_call):
    """Pattern for decomposed erf-based GELU: x * 0.5 * (1 + erf(x * 1/sqrt(2)))."""
    return CallFunction(
        aten.mul,
        CallFunction(aten.mul, computation_call, 0.5),
        CallFunction(
            aten.add,
            CallFunction(
                aten.erf,
                # 0.7071067811865476 == 1/sqrt(2)
                CallFunction(aten.mul, computation_call, 0.7071067811865476),
            ),
            1,
        ),
    )
|
| 106 |
+
|
| 107 |
+
def _gelu_fusion_2(computation_call):
    """Pattern for decomposed tanh-approximation GELU:
    x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    return CallFunction(
        aten.mul,
        CallFunction(aten.mul, computation_call, 0.5),
        CallFunction(
            aten.add,
            CallFunction(
                aten.tanh,
                CallFunction(
                    aten.mul,
                    CallFunction(
                        aten.add,
                        computation_call,
                        CallFunction(
                            aten.mul,
                            CallFunction(
                                aten.mul,
                                # x^3 built as (x * x) * x
                                CallFunction(
                                    aten.mul, computation_call, computation_call
                                ),
                                computation_call,
                            ),
                            0.044715,
                        ),
                    ),
                    # 0.7978845608028654 == sqrt(2/pi)
                    0.7978845608028654,
                ),
            ),
            1,
        ),
    )
|
| 138 |
+
|
| 139 |
+
def _hardswish_fusion(computation_call):
    """Pattern for decomposed hardswish: x * clamp(x + 3, 0, 6) / 6."""
    return CallFunction(
        aten.div,
        CallFunction(
            aten.mul,
            computation_call,
            CallFunction(
                aten.clamp_max,
                CallFunction(
                    aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
                ),
                6,
            ),
        ),
        6,
    )
|
| 155 |
+
|
| 156 |
+
def _silu_fusion(computation_call):
    """Pattern for decomposed SiLU/swish: x * sigmoid(x)."""
    return CallFunction(
        aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call)
    )
|
| 160 |
+
|
| 161 |
+
def _hardsigmoid_fusion(computation_call):
    """Pattern for decomposed hardsigmoid: clamp(x + 3, 0, 6) / 6."""
    return CallFunction(
        aten.div,
        CallFunction(
            aten.clamp_max,
            CallFunction(
                aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
            ),
            6,
        ),
        6,
    )
|
| 173 |
+
|
| 174 |
+
def _leaky_relu_fusion(computation_call):
    """Pattern for decomposed leaky_relu: where(x > 0, x, x * negative_slope)."""
    return CallFunction(
        aten.where,
        CallFunction(aten.gt, computation_call, 0),
        computation_call,
        CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")),
    )
|
| 181 |
+
|
| 182 |
+
def _hardtanh_fusion(computation_call):
    """Pattern for decomposed hardtanh: clamp(x, min_value, max_value)."""
    return CallFunction(
        aten.clamp_max,
        CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
        KeywordArg("max_value"),
    )
|
| 188 |
+
|
| 189 |
+
def _combined_fusion(computation_call, elementwise_op):
    """Pattern: a single elementwise op applied directly to the computation op."""
    return CallFunction(elementwise_op, computation_call)
|
| 191 |
+
|
| 192 |
+
# binary_op(other, computation_op)
def _binary_fusion_v1(computation_call, binary_fn):
    """Pattern: binary op with the extra input on the left-hand side."""
    return CallFunction(binary_fn, KeywordArg("other"), computation_call)
|
| 195 |
+
|
| 196 |
+
# binary_op(computation_op, other)
def _binary_fusion_v2(computation_call, binary_fn):
    """Pattern: binary op with the extra input on the right-hand side."""
    return CallFunction(binary_fn, computation_call, KeywordArg("other"))
|
| 199 |
+
|
| 200 |
+
def _is_single_computation_op(computation_op, lowp_dtype=None):
    """Build an ``extra_check`` predicate for a lowering pattern.

    The returned check accepts a match only if it contains at least one
    `computation_op` node that has no unary post-op fused yet, and — when
    `lowp_dtype` is given — the match's output value has that dtype.
    """

    def fn(match):
        computation_nodes = filter_nodes(match.nodes, computation_op)

        if lowp_dtype:
            output_node_meta = match.output_node().meta.get("val")
            # Fix: meta may lack a "val" entry; previously this raised
            # AttributeError on None.dtype — treat it as a non-match instead.
            if output_node_meta is None or output_node_meta.dtype != lowp_dtype:
                return False

        if len(computation_nodes) < 1:
            return False
        # args[-3] is the fused-attr slot of the mkldnn op; "none" means no
        # unary post-op has been attached yet.
        if any(n.args[-3] != "none" for n in computation_nodes):
            return False
        return True

    return fn
|
| 216 |
+
|
| 217 |
+
def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
    """Build an ``extra_check`` for unary fusions; additionally verifies that
    any dtype-conversion pair in the match really is float32 up-cast followed
    by a down-cast back to `lowp_dtype`."""

    def fn(match):
        matched = _is_single_computation_op(computation_op, lowp_dtype)(match)
        computation_node = filter_nodes(match.nodes, computation_op)[0]
        if lowp_dtype:
            conversion_dtype_nodes = filter_nodes(
                match.nodes, prims.convert_element_type.default
            )
            if len(conversion_dtype_nodes) != 2:
                return False
            # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16
            if computation_node == conversion_dtype_nodes[0].args[0]:
                # conversion_dtype_nodes[0] is the up-cast fed by the op
                to_float = conversion_dtype_nodes[0].args[1]
                to_lp = conversion_dtype_nodes[1].args[1]
            else:
                to_float = conversion_dtype_nodes[1].args[1]
                to_lp = conversion_dtype_nodes[0].args[1]
            matched = matched and to_float == torch.float and to_lp == lowp_dtype
        return matched

    return fn
|
| 238 |
+
|
| 239 |
+
def _register_unary_fusion_lowering(
    pattern, unary_attr, computation_op, lowp_dtype=None
):
    """Register a lowering that folds a matched unary op into `computation_op`
    by overwriting its trailing (attr, scalars, algorithm) argument triple."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
    )
    def fn(match, *args, **kwargs):
        # Drop the last 3 args (the unfused "none" attr triple) and append
        # the fused unary's attributes instead.
        computation_args = list(args)[:-3] + [
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        ]
        return L[computation_op](*computation_args)

    return fn
|
| 255 |
+
|
| 256 |
+
def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
    """Register a lowering for leaky_relu fused after `computation_op`.

    If `negative_slope` is a compile-time Number (and, for low precision, the
    matched casts agree), the leaky_relu is folded into the op; otherwise the
    op is lowered unfused and the leaky_relu is re-emitted explicitly.
    """

    @register_lowering_pattern(
        pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
    )
    def fn(match, *args, **kwargs):
        negative_slope = kwargs.get("negative_slope")
        if isinstance(negative_slope, ir.TensorBox):
            # Tensor-valued slope cannot be baked into the fused op.
            matched = False
        else:  # inp is a Number
            matched = True
        if lowp_dtype:
            dtype1 = kwargs.get("to_float")
            dtype2 = (
                kwargs.get("to_bf16")
                if lowp_dtype == torch.bfloat16
                else kwargs.get("to_fp16")
            )
            matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
        computation_args = list(args)
        if matched:
            computation_args = computation_args[:-3] + [
                "leaky_relu",
                [negative_slope],
                "",
            ]
            return L[computation_op](*computation_args)
        else:
            # Fallback: lower the computation op unfused, then rebuild
            # leaky_relu as where(out > 0, out, out * negative_slope).
            # computation_args += ["none", [], ""]
            out = L[computation_op](*computation_args)
            if lowp_dtype:
                out = L[prims.convert_element_type.default](out, dtype=torch.float)
            out = L[aten.where](
                L[aten.gt](out, 0),
                out,
                L[aten.mul](out, negative_slope),
            )
            if lowp_dtype:
                out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
            return out

    return fn
|
| 297 |
+
|
| 298 |
+
def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
    """Register a lowering for hardtanh (clamp) fused after `computation_op`.

    Folds the clamp when min/max are compile-time Numbers with
    min_value <= max_value (and the low-precision casts line up); otherwise
    re-emits clamp_min/clamp_max around the unfused op.
    """

    @register_lowering_pattern(
        pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
    )
    def fn(match, *args, **kwargs):
        min_value = kwargs.get("min_value")
        max_value = kwargs.get("max_value")
        if isinstance(min_value, ir.TensorBox) or isinstance(
            max_value, ir.TensorBox
        ):
            # Tensor bounds cannot be baked into the fused op.
            matched = False
        else:  # inp is a Number
            assert max_value is not None
            matched = min_value <= max_value
        if lowp_dtype:
            dtype1 = kwargs.get("to_float")
            dtype2 = (
                kwargs.get("to_bf16")
                if lowp_dtype == torch.bfloat16
                else kwargs.get("to_fp16")
            )
            matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
        computation_args = list(args)
        if matched:
            computation_args = computation_args[:-3] + [
                "hardtanh",
                [min_value, max_value],
                "",
            ]
            return L[computation_op](*computation_args)
        else:
            # Fallback: unfused op followed by an explicit clamp pair.
            out = L[computation_op](*computation_args)
            if lowp_dtype:
                out = L[prims.convert_element_type.default](out, dtype=torch.float)
            out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
            if lowp_dtype:
                out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
            return out

    return fn
|
| 338 |
+
|
| 339 |
+
# Maps a matched binary op (aten or virtualized) to the mkldnn attr string
# passed to the fused computation op.
_binary_attr = {
    aten.add: "add",
    ops.add: "add",
    aten.sub: "sub",
    ops.sub: "sub",
}
|
| 345 |
+
|
| 346 |
+
def _is_valid_binary(match, fn):
    """Return True iff every matched `fn` node is a tensor-tensor binary op
    that is safe to fuse: both inputs are Tensors with identical size, device
    and dtype, alpha (if present) is 1, and the two inputs are distinct."""
    binary_nodes = filter_nodes(match.nodes, fn)
    if len(binary_nodes) < 1:
        return False

    def get_meta_value(argument: torch.fx.node.Argument):
        # Only torch.fx.Node is expected to have meta.
        if isinstance(argument, torch.fx.Node):
            return argument.meta.get("val", None)
        return None

    if any(
        not isinstance(get_meta_value(n.args[0]), torch.Tensor)
        or not isinstance(get_meta_value(n.args[1]), torch.Tensor)
        for n in binary_nodes
    ):
        return False
    # check alpha is one.
    if any(
        get_arg_value(n, 2, kwarg_name="alpha") != 1.0
        and get_arg_value(n, 2, kwarg_name="alpha") is not None
        for n in binary_nodes
    ):
        return False
    if any(
        get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
        or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
        or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
        for n in binary_nodes
    ):
        return False
    # check args[0] and args[1] is not same
    if any(n.args[0] == n.args[1] for n in binary_nodes):
        return False
    return True
|
| 381 |
+
|
| 382 |
+
def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
    """Build an ``extra_check``: the match must both contain an unfused
    `computation_op` and a fusable `binary_op` (see _is_valid_binary).
    `other_index` is accepted for signature parity but unused here."""

    def fn(match):
        return _is_single_computation_op(computation_op)(match) and _is_valid_binary(
            match, binary_op
        )

    return fn
|
| 391 |
+
|
| 392 |
+
def _get_remaining_users(extra_input_node, compute_node):
    """Return the users of `extra_input_node` that are NOT ancestors of
    `compute_node` (i.e. users whose reads may still be pending)."""
    # Think about this pattern:
    #      ReLU
    #     /    \
    #  Conv1
    #   /      \
    # Conv2
    #   \      /
    #      Add
    # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add.
    # The Conv1 is the ancestor node of the current compute node (Conv2).
    # This indicates that the buffer of ReLU has completed all its usage,
    # So we can safely make changes to it now by doing Conv2->Add inplace fusion.
    # Take above case as example:
    # * extra_input_node: ReLU
    # * compute_node: Conv2
    # _get_remaining_users will return the users of extra_input_node which are not
    # ancestor node of compute_node.
    def _is_ancestor_node(_current_node, _ancestor_node):
        # Check whether _ancestor_node is the ancestor node of _current_node
        # BFS backwards through inputs, with a visited set to avoid rework.
        _node_list = [_current_node]
        _visited_nodes = set()
        while len(_node_list) != 0:
            _current_node = _node_list.pop(0)
            if _current_node not in _visited_nodes:
                _visited_nodes.add(_current_node)
                if _current_node == _ancestor_node:
                    return True
                elif isinstance(
                    _current_node, torch.fx.Node
                ) and _current_node.op not in ["placeholder", "output", "get_attr"]:
                    for input in _current_node.all_input_nodes:
                        _node_list.append(input)  # noqa: PERF402
        return False

    return [
        user
        for user in list(extra_input_node.users)
        if not _is_ancestor_node(compute_node, user)
    ]
|
| 432 |
+
|
| 433 |
+
def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
    """Build an ``extra_check`` for in-place binary fusion: on top of
    _is_valid_computation_binary, the "other" input's buffer must be safe to
    overwrite (no remaining users past the compute node, not the compute
    node's own input, and not a graph placeholder/output)."""

    def fn(match):
        if not _is_valid_computation_binary(computation_op, binary_op)(match):
            return False
        binary_nodes = filter_nodes(match.nodes, binary_op)

        def _get_compute_node(_binary_node, _other_index):
            assert (
                len(_binary_node.all_input_nodes) == 2
            ), "Binary node should have 2 input nodes."
            # The computation node occupies whichever arg slot "other" doesn't.
            _compute_index = 1 if (_other_index == 0) else 0
            return _binary_node.args[_compute_index]

        def _other_input_not_inplaceable(_binary_node, _other_index):
            _compute_node = _get_compute_node(_binary_node, _other_index)
            return (
                len(
                    _get_remaining_users(
                        _binary_node.args[_other_index], _compute_node
                    )
                )
                > 1
                or _binary_node.args[_other_index] == _compute_node.args[0]
            )

        if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
            return False
        if any(
            n.args[other_index].op in ["placeholder", "output"]
            for n in binary_nodes
        ):
            return False
        return True

    return fn
|
| 468 |
+
|
| 469 |
+
def _register_binary_unary_fusion_lowering(
    pattern,
    computation_op,
    binary_op,
    fusion_op,
    unary_attr=None,
):
    """Register an out-of-place lowering fusing binary_op (and optionally a
    trailing unary given by `unary_attr`) into `fusion_op`."""

    @register_lowering_pattern(
        pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
    )
    def fn(match, *args, **kwargs):
        other = kwargs.get("other")
        assert isinstance(other, ir.TensorBox)
        binary_attr = _binary_attr[binary_op]
        args_list = list(args)
        # Splice "other" in as the second arg and replace the trailing attr
        # triple with the binary attr string.
        computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
        if len(args_list) > 6:
            # Convolution-shaped arg list: also supply (alpha, unary attrs).
            if unary_attr is not None:
                computation_args += [
                    1.0,
                    unary_attr.op_name,
                    unary_attr.scalars_attr,
                    unary_attr.algorithm_attr,
                ]
            else:
                computation_args += [1.0, None, [], None]
        return L[fusion_op](*computation_args)

    return fn
|
| 498 |
+
|
| 499 |
+
def _can_be_inplace(_other):
    """Return True iff `_other`'s buffer may be written in place: after
    unwrapping plain Views, it must not be a ReinterpretView and must not
    alias any of its inputs in its output."""
    candidate = _other
    # Look through chained plain views to the underlying buffer.
    while isinstance(candidate.data, ir.View):
        candidate = candidate.data
    if isinstance(candidate.data, ir.ReinterpretView):
        return False
    return len(candidate.get_inputs_that_alias_output()) == 0
|
| 507 |
+
|
| 508 |
+
def _register_binary_unary_maybe_inplace_fusion_lowering(
    pattern,
    computation_op,
    binary_op,
    inplace_fusion_op,
    outplace_fusion_op,
    unary_attr=None,
    other_index=None,
):
    """Register a lowering that prefers the in-place fused op when "other"'s
    buffer can be safely overwritten, else falls back to the out-of-place op."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_computation_binary_inplace(
            computation_op, binary_op, other_index
        ),
    )
    def fn(match, *args, **kwargs):
        other = kwargs.get("other")
        assert isinstance(other, ir.TensorBox)
        binary_attr = _binary_attr[binary_op]
        args_list = list(args)
        computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
        if len(args_list) > 6:
            # Convolution-shaped arg list: also supply (alpha, unary attrs).
            if unary_attr is not None:
                computation_args += [
                    1.0,
                    unary_attr.op_name,
                    unary_attr.scalars_attr,
                    unary_attr.algorithm_attr,
                ]
            else:
                computation_args += [1.0, None, [], None]
        # Make sure the other is not an alias or mutation(fx side doesn't has such info).
        other.realize()
        if not _can_be_inplace(other):
            return L[outplace_fusion_op](*computation_args)
        return L[inplace_fusion_op](*computation_args)

    return fn
|
| 546 |
+
|
| 547 |
+
# The mkldnn computation ops that unary/binary post-ops may be fused into.
computation_ops = [
    mkldnn._convolution_pointwise.default,
    mkldnn._linear_pointwise.default,
    mkldnn._convolution_transpose_pointwise.default,
]
|
| 552 |
+
|
| 553 |
+
class UnaryAttr:
    """Value object bundling a fused unary post-op's name, scalar arguments,
    and algorithm string for the mkldnn computation ops."""

    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None) -> None:
        # Falsy inputs (None / empty) normalize to the neutral defaults.
        self.op_name = op_name
        self.scalars_attr = scalars_attr or []
        self.algorithm_attr = algorithm_attr or ""
|
| 560 |
+
|
| 561 |
+
def _register_unary_fusion():
|
| 562 |
+
computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]
|
| 563 |
+
|
| 564 |
+
def _unary_fusion_patterns(lowp_dtype):
|
| 565 |
+
replacement_unary_fusion_patterns = {
|
| 566 |
+
UnaryAttr("gelu", algorithm_attr="tanh"): [
|
| 567 |
+
_unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype)
|
| 568 |
+
for call_fn in computation_call_fns
|
| 569 |
+
],
|
| 570 |
+
UnaryAttr("gelu", algorithm_attr="none"): [
|
| 571 |
+
_unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype)
|
| 572 |
+
for call_fn in computation_call_fns
|
| 573 |
+
],
|
| 574 |
+
UnaryAttr("hardswish"): [
|
| 575 |
+
_unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype)
|
| 576 |
+
for call_fn in computation_call_fns
|
| 577 |
+
],
|
| 578 |
+
UnaryAttr("hardsigmoid"): [
|
| 579 |
+
_unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype)
|
| 580 |
+
for call_fn in computation_call_fns
|
| 581 |
+
],
|
| 582 |
+
UnaryAttr("swish"): [
|
| 583 |
+
_unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype)
|
| 584 |
+
for call_fn in computation_call_fns
|
| 585 |
+
],
|
| 586 |
+
}
|
| 587 |
+
if not lowp_dtype:
|
| 588 |
+
call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
|
| 589 |
+
replacement_unary_fusion_patterns.update(
|
| 590 |
+
{
|
| 591 |
+
UnaryAttr("relu"): [
|
| 592 |
+
_combined_fusion(u, aten.relu) for u in call_user1
|
| 593 |
+
],
|
| 594 |
+
UnaryAttr("sigmoid"): [
|
| 595 |
+
_combined_fusion(u, aten.sigmoid) for u in call_user1
|
| 596 |
+
],
|
| 597 |
+
UnaryAttr("tanh"): [
|
| 598 |
+
_combined_fusion(u, aten.tanh) for u in call_user1
|
| 599 |
+
],
|
| 600 |
+
}
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
return replacement_unary_fusion_patterns
|
| 604 |
+
|
| 605 |
+
for lowp_dtype in [torch.bfloat16, torch.float16, None]:
|
| 606 |
+
replace_patterns = _unary_fusion_patterns(lowp_dtype)
|
| 607 |
+
for unary_attr, patterns in replace_patterns.items():
|
| 608 |
+
_register_unary_fusion_lowering(
|
| 609 |
+
patterns[0], unary_attr, computation_ops[0], lowp_dtype
|
| 610 |
+
)
|
| 611 |
+
_register_unary_fusion_lowering(
|
| 612 |
+
patterns[1], unary_attr, computation_ops[1], lowp_dtype
|
| 613 |
+
)
|
| 614 |
+
_register_unary_fusion_lowering(
|
| 615 |
+
patterns[2], unary_attr, computation_ops[2], lowp_dtype
|
| 616 |
+
)
|
| 617 |
+
_leaky_relu_patterns = [
|
| 618 |
+
_unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
|
| 619 |
+
for call_fn in computation_call_fns
|
| 620 |
+
]
|
| 621 |
+
for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
|
| 622 |
+
_register_leaky_relu_fusion_lowering(
|
| 623 |
+
pattern, computation_op, lowp_dtype
|
| 624 |
+
)
|
| 625 |
+
hardtanh_patterns = [
|
| 626 |
+
_unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
|
| 627 |
+
for call_fn in computation_call_fns
|
| 628 |
+
]
|
| 629 |
+
for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
|
| 630 |
+
_register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
|
| 631 |
+
|
| 632 |
+
def _register_inplace_fusion():
|
| 633 |
+
binary_ops = [aten.add, ops.add]
|
| 634 |
+
inplace_fusion_op = mkldnn._convolution_pointwise_.binary
|
| 635 |
+
outplace_fusion_op = mkldnn._convolution_pointwise.binary
|
| 636 |
+
conv_call = _conv_call(users=1)
|
| 637 |
+
conv_op = computation_ops[0]
|
| 638 |
+
for binary_op in binary_ops:
|
| 639 |
+
binary_v1 = _binary_fusion_v1(conv_call, binary_op)
|
| 640 |
+
binary_unary_v1 = _combined_fusion(binary_v1, aten.relu)
|
| 641 |
+
_register_binary_unary_maybe_inplace_fusion_lowering(
|
| 642 |
+
binary_unary_v1,
|
| 643 |
+
conv_op,
|
| 644 |
+
binary_op,
|
| 645 |
+
inplace_fusion_op,
|
| 646 |
+
outplace_fusion_op,
|
| 647 |
+
other_index=0,
|
| 648 |
+
unary_attr=UnaryAttr("relu"),
|
| 649 |
+
)
|
| 650 |
+
_register_binary_unary_maybe_inplace_fusion_lowering(
|
| 651 |
+
binary_v1,
|
| 652 |
+
conv_op,
|
| 653 |
+
binary_op,
|
| 654 |
+
inplace_fusion_op,
|
| 655 |
+
outplace_fusion_op,
|
| 656 |
+
other_index=0,
|
| 657 |
+
)
|
| 658 |
+
binary_v2 = _binary_fusion_v2(conv_call, binary_op)
|
| 659 |
+
binary_unary_v2 = _combined_fusion(binary_v2, aten.relu)
|
| 660 |
+
_register_binary_unary_maybe_inplace_fusion_lowering(
|
| 661 |
+
binary_unary_v2,
|
| 662 |
+
conv_op,
|
| 663 |
+
binary_op,
|
| 664 |
+
inplace_fusion_op,
|
| 665 |
+
outplace_fusion_op,
|
| 666 |
+
other_index=1,
|
| 667 |
+
unary_attr=UnaryAttr("relu"),
|
| 668 |
+
)
|
| 669 |
+
_register_binary_unary_maybe_inplace_fusion_lowering(
|
| 670 |
+
binary_v2,
|
| 671 |
+
conv_op,
|
| 672 |
+
binary_op,
|
| 673 |
+
inplace_fusion_op,
|
| 674 |
+
outplace_fusion_op,
|
| 675 |
+
other_index=1,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
def _register_binary_fusion():
|
| 679 |
+
binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
|
| 680 |
+
fusion_ops = [
|
| 681 |
+
mkldnn._convolution_pointwise.binary,
|
| 682 |
+
mkldnn._linear_pointwise.binary,
|
| 683 |
+
]
|
| 684 |
+
_computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
|
| 685 |
+
for computation_call, computation_op, fusion_op in zip(
|
| 686 |
+
_computation_user_1, computation_ops[:-1], fusion_ops
|
| 687 |
+
):
|
| 688 |
+
for binary_op in binary_ops:
|
| 689 |
+
pattern = _binary_fusion_v2(computation_call, binary_op)
|
| 690 |
+
_register_binary_unary_fusion_lowering(
|
| 691 |
+
pattern, computation_op, binary_op, fusion_op
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
for binary_op in [aten.add, ops.add]:
|
| 695 |
+
pattern = _binary_fusion_v1(computation_call, binary_op)
|
| 696 |
+
_register_binary_unary_fusion_lowering(
|
| 697 |
+
pattern, computation_op, binary_op, fusion_op
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
def _register_binary_unary_fusion():
|
| 701 |
+
binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
|
| 702 |
+
fusion_ops = [mkldnn._convolution_pointwise.binary]
|
| 703 |
+
_computation_user_1 = [_conv_call(users=1)]
|
| 704 |
+
for computation_call, computation_op, fusion_op in zip(
|
| 705 |
+
_computation_user_1, computation_ops[:-1], fusion_ops
|
| 706 |
+
):
|
| 707 |
+
for binary_op in binary_ops:
|
| 708 |
+
pattern_v1 = _combined_fusion(
|
| 709 |
+
_binary_fusion_v2(computation_call, binary_op), aten.relu
|
| 710 |
+
)
|
| 711 |
+
_register_binary_unary_fusion_lowering(
|
| 712 |
+
pattern_v1,
|
| 713 |
+
computation_op,
|
| 714 |
+
binary_op,
|
| 715 |
+
fusion_op,
|
| 716 |
+
unary_attr=UnaryAttr("relu"),
|
| 717 |
+
)
|
| 718 |
+
for binary_op in [aten.add, ops.add]:
|
| 719 |
+
pattern_v2 = _combined_fusion(
|
| 720 |
+
_binary_fusion_v1(computation_call, binary_op), aten.relu
|
| 721 |
+
)
|
| 722 |
+
_register_binary_unary_fusion_lowering(
|
| 723 |
+
pattern_v2,
|
| 724 |
+
computation_op,
|
| 725 |
+
binary_op,
|
| 726 |
+
fusion_op,
|
| 727 |
+
unary_attr=UnaryAttr("relu"),
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
def _recover_linear():
|
| 731 |
+
# convert reshape+linear+reshape to a single linear for applying fusion path.
|
| 732 |
+
@register_freezing_graph_pattern(
|
| 733 |
+
CallFunction(
|
| 734 |
+
aten.reshape.default,
|
| 735 |
+
CallFunction(
|
| 736 |
+
mkldnn._linear_pointwise.default,
|
| 737 |
+
CallFunction(
|
| 738 |
+
aten.reshape.default,
|
| 739 |
+
Arg(),
|
| 740 |
+
KeywordArg("reshape_1"),
|
| 741 |
+
_users=MULTIPLE,
|
| 742 |
+
),
|
| 743 |
+
Arg(),
|
| 744 |
+
Arg(),
|
| 745 |
+
Arg(),
|
| 746 |
+
Arg(),
|
| 747 |
+
Arg(),
|
| 748 |
+
),
|
| 749 |
+
KeywordArg("reshape_2"),
|
| 750 |
+
),
|
| 751 |
+
pass_number=1,
|
| 752 |
+
)
|
| 753 |
+
def reshape_linear_reshape_pattern(match, *args, **kwargs):
|
| 754 |
+
def get_val(val):
|
| 755 |
+
return val if isinstance(val, int) else val.meta.get("val")
|
| 756 |
+
|
| 757 |
+
reshape_1 = kwargs.get("reshape_1")
|
| 758 |
+
reshape_2 = kwargs.get("reshape_2")
|
| 759 |
+
assert isinstance(reshape_1, list)
|
| 760 |
+
assert isinstance(reshape_2, list)
|
| 761 |
+
assert len(reshape_1) == 2
|
| 762 |
+
|
| 763 |
+
graph = match.graph
|
| 764 |
+
reshape_2_node = match.output_node()
|
| 765 |
+
linear_input_node = reshape_2_node.args[0].args[0].args[0]
|
| 766 |
+
# check linear's input's shape[:-1] == reshape_2[:-1]
|
| 767 |
+
# and check product(reshape_2[:-1]) == reshape_1[0]
|
| 768 |
+
can_remove_reshape = linear_input_node.meta.get("val").shape[
|
| 769 |
+
:-1
|
| 770 |
+
] == torch.Size([get_val(val) for val in reshape_2[:-1]])
|
| 771 |
+
can_remove_reshape = can_remove_reshape and (
|
| 772 |
+
reduce(
|
| 773 |
+
operator.mul,
|
| 774 |
+
[get_val(val) for val in reshape_2[:-1]],
|
| 775 |
+
)
|
| 776 |
+
== get_val(reshape_1[0])
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
if can_remove_reshape:
|
| 780 |
+
repl = graph.call_function(mkldnn._linear_pointwise.default, args)
|
| 781 |
+
repl.meta.update(reshape_2_node.meta)
|
| 782 |
+
reshape_2_node.replace_all_uses_with(repl)
|
| 783 |
+
old_linear_node = reshape_2_node.args[0]
|
| 784 |
+
reshape_1_node = old_linear_node.args[0]
|
| 785 |
+
graph.erase_node(reshape_2_node)
|
| 786 |
+
graph.erase_node(old_linear_node)
|
| 787 |
+
if len(reshape_1_node.users) == 0:
|
| 788 |
+
graph.erase_node(reshape_1_node)
|
| 789 |
+
|
| 790 |
+
def is_linear_add_bias(match):
|
| 791 |
+
add_node = match.output_node()
|
| 792 |
+
linear_node = add_node.args[0]
|
| 793 |
+
packed_weight_node = linear_node.args[1]
|
| 794 |
+
assert packed_weight_node.target == mkldnn._reorder_linear_weight
|
| 795 |
+
transpose_weight_node = packed_weight_node.args[0]
|
| 796 |
+
assert transpose_weight_node.target == aten.permute.default
|
| 797 |
+
weight_meta = transpose_weight_node.args[0].meta.get("val")
|
| 798 |
+
bias_node = add_node.args[1]
|
| 799 |
+
if isinstance(bias_node, int):
|
| 800 |
+
# we only folding bias if it is a constant
|
| 801 |
+
return False
|
| 802 |
+
bias_meta = add_node.args[1].meta.get("val")
|
| 803 |
+
if weight_meta is None or bias_meta is None:
|
| 804 |
+
return False
|
| 805 |
+
assert weight_meta.dtype in (
|
| 806 |
+
torch.bfloat16,
|
| 807 |
+
torch.float16,
|
| 808 |
+
)
|
| 809 |
+
if bias_meta.dtype != weight_meta.dtype:
|
| 810 |
+
return False
|
| 811 |
+
return (
|
| 812 |
+
linear_node.args[2] is None
|
| 813 |
+
and bias_meta.dim() == 1
|
| 814 |
+
and bias_meta.size(0) == weight_meta.size(1)
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
# convert linear+bias to a single linear for applying fusion path.
|
| 818 |
+
@register_freezing_graph_pattern(
|
| 819 |
+
CallFunction(
|
| 820 |
+
aten.add.Tensor,
|
| 821 |
+
CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
|
| 822 |
+
Arg(),
|
| 823 |
+
),
|
| 824 |
+
pass_number=1,
|
| 825 |
+
extra_check=is_linear_add_bias,
|
| 826 |
+
)
|
| 827 |
+
def linear_bias_pattern(match, *args):
|
| 828 |
+
graph = match.graph
|
| 829 |
+
add_node = match.output_node()
|
| 830 |
+
linear_node = add_node.args[0]
|
| 831 |
+
new_args = list(linear_node.args)
|
| 832 |
+
new_args[2] = add_node.args[1]
|
| 833 |
+
repl = graph.call_function(
|
| 834 |
+
mkldnn._linear_pointwise.default, tuple(new_args)
|
| 835 |
+
)
|
| 836 |
+
repl.meta.update(add_node.meta)
|
| 837 |
+
add_node.replace_all_uses_with(repl)
|
| 838 |
+
match.erase_nodes()
|
| 839 |
+
|
| 840 |
+
def _is_packable_mkldnn_rnn_layer(match):
|
| 841 |
+
lstm_node = match.output_node()
|
| 842 |
+
POS_WEIGHTS = [1, 2]
|
| 843 |
+
POS_INPUTS = [0, 5, 6]
|
| 844 |
+
POS_ARGS = POS_WEIGHTS + POS_INPUTS
|
| 845 |
+
# Weights should be Constant
|
| 846 |
+
if any(
|
| 847 |
+
lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS
|
| 848 |
+
):
|
| 849 |
+
return False
|
| 850 |
+
|
| 851 |
+
# Meta info for weights and inputs should be available
|
| 852 |
+
if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):
|
| 853 |
+
return False
|
| 854 |
+
|
| 855 |
+
# Check device
|
| 856 |
+
if any(
|
| 857 |
+
lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"
|
| 858 |
+
for POS_ARG in POS_ARGS
|
| 859 |
+
):
|
| 860 |
+
return False
|
| 861 |
+
|
| 862 |
+
# Check dtype
|
| 863 |
+
if any(
|
| 864 |
+
lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16
|
| 865 |
+
and not mkldnn._is_mkldnn_bf16_supported()
|
| 866 |
+
for POS_ARG in POS_ARGS
|
| 867 |
+
):
|
| 868 |
+
return False
|
| 869 |
+
if any(
|
| 870 |
+
lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16
|
| 871 |
+
and not mkldnn._is_mkldnn_fp16_supported()
|
| 872 |
+
for POS_ARG in POS_ARGS
|
| 873 |
+
):
|
| 874 |
+
return False
|
| 875 |
+
|
| 876 |
+
return True
|
| 877 |
+
|
| 878 |
+
def _is_packable_convolution(match):
|
| 879 |
+
"""
|
| 880 |
+
Check if the node is supported for MKLDNN convolution.
|
| 881 |
+
"""
|
| 882 |
+
conv_node = match.output_node()
|
| 883 |
+
input_meta_value = conv_node.args[0].meta.get("val")
|
| 884 |
+
weight_meta_value = conv_node.args[1].meta.get("val")
|
| 885 |
+
if input_meta_value is None or weight_meta_value is None:
|
| 886 |
+
return False
|
| 887 |
+
input_size = input_meta_value.shape
|
| 888 |
+
if conv_node.args[1].op != "get_attr":
|
| 889 |
+
return False
|
| 890 |
+
for meta_value in [input_meta_value, weight_meta_value]:
|
| 891 |
+
if (
|
| 892 |
+
meta_value is None
|
| 893 |
+
or meta_value.device.type != "cpu"
|
| 894 |
+
or (meta_value.dim() != 4 and meta_value.dim() != 5)
|
| 895 |
+
):
|
| 896 |
+
return False
|
| 897 |
+
if (
|
| 898 |
+
input_meta_value.dtype == torch.bfloat16
|
| 899 |
+
or weight_meta_value.dtype == torch.bfloat16
|
| 900 |
+
):
|
| 901 |
+
if not mkldnn._is_mkldnn_bf16_supported():
|
| 902 |
+
return False
|
| 903 |
+
if (
|
| 904 |
+
input_meta_value.dtype == torch.float16
|
| 905 |
+
or weight_meta_value.dtype == torch.float16
|
| 906 |
+
):
|
| 907 |
+
if not mkldnn._is_mkldnn_fp16_supported():
|
| 908 |
+
return False
|
| 909 |
+
is_transposed = conv_node.args[-3]
|
| 910 |
+
if is_transposed:
|
| 911 |
+
# TODO: Support dynamic shape case for MKLDNN conv transpose.
|
| 912 |
+
if has_free_symbols(input_size):
|
| 913 |
+
return False
|
| 914 |
+
groups = conv_node.args[-1]
|
| 915 |
+
in_channels = weight_meta_value.size(0)
|
| 916 |
+
# doesn't support group_depthwise_conv_transpose.
|
| 917 |
+
if groups > 1 and groups == in_channels:
|
| 918 |
+
return False
|
| 919 |
+
# Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
|
| 920 |
+
output_paddings = conv_node.args[-2]
|
| 921 |
+
strides = conv_node.args[3]
|
| 922 |
+
if any(
|
| 923 |
+
output_padding >= stride
|
| 924 |
+
for output_padding, stride in zip(output_paddings, strides)
|
| 925 |
+
):
|
| 926 |
+
return False
|
| 927 |
+
return True
|
| 928 |
+
|
| 929 |
+
def _is_packable_linear(match):
|
| 930 |
+
"""
|
| 931 |
+
Check if the node is supported for MKLDNN linear.
|
| 932 |
+
"""
|
| 933 |
+
linear_node = match.output_node()
|
| 934 |
+
# mkldnn linear only supports beta=1or0 and alpha=1
|
| 935 |
+
if linear_node.target == aten.addmm.default:
|
| 936 |
+
alpha = linear_node.kwargs.get("alpha", 1.0)
|
| 937 |
+
beta = linear_node.kwargs.get("beta", 1.0)
|
| 938 |
+
if (beta != 0.0 and beta != 1.0) or alpha != 1.0:
|
| 939 |
+
return False
|
| 940 |
+
# weight_idx is 1 for aten.mm and is 2 for aten.addmm
|
| 941 |
+
weight_idx = 2 if linear_node.target == aten.addmm.default else 1
|
| 942 |
+
if linear_node.args[weight_idx].op != "get_attr":
|
| 943 |
+
return False
|
| 944 |
+
input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
|
| 945 |
+
weight_meta_value = linear_node.args[weight_idx].meta.get("val")
|
| 946 |
+
if input_meta_value is None or weight_meta_value is None:
|
| 947 |
+
return False
|
| 948 |
+
batch_size = input_meta_value.shape[0]
|
| 949 |
+
if (
|
| 950 |
+
input_meta_value.dtype == torch.float64
|
| 951 |
+
or weight_meta_value.dtype == torch.float64
|
| 952 |
+
):
|
| 953 |
+
return False
|
| 954 |
+
is_lp_weight = weight_meta_value.dtype in (
|
| 955 |
+
torch.bfloat16,
|
| 956 |
+
torch.float16,
|
| 957 |
+
)
|
| 958 |
+
# on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
|
| 959 |
+
# on aarch64, use mkldnn op for fp32 as well if acl is enabled
|
| 960 |
+
if (
|
| 961 |
+
not is_lp_weight
|
| 962 |
+
and not mkldnn._is_mkldnn_acl_supported()
|
| 963 |
+
and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
|
| 964 |
+
):
|
| 965 |
+
return False
|
| 966 |
+
for meta_value in [input_meta_value, weight_meta_value]:
|
| 967 |
+
if (
|
| 968 |
+
meta_value is None
|
| 969 |
+
or meta_value.device.type != "cpu"
|
| 970 |
+
or meta_value.dim() != 2
|
| 971 |
+
):
|
| 972 |
+
return False
|
| 973 |
+
if weight_idx == 2:
|
| 974 |
+
bias_meta_value = linear_node.args[0].meta.get("val")
|
| 975 |
+
if (
|
| 976 |
+
bias_meta_value is None
|
| 977 |
+
or meta_value.device.type != "cpu"
|
| 978 |
+
or bias_meta_value.dim() != 1
|
| 979 |
+
or bias_meta_value.size(0) != weight_meta_value.size(1)
|
| 980 |
+
):
|
| 981 |
+
return False
|
| 982 |
+
|
| 983 |
+
if (
|
| 984 |
+
input_meta_value.dtype == torch.bfloat16
|
| 985 |
+
or weight_meta_value.dtype == torch.bfloat16
|
| 986 |
+
):
|
| 987 |
+
if not mkldnn._is_mkldnn_bf16_supported():
|
| 988 |
+
return False
|
| 989 |
+
if (
|
| 990 |
+
input_meta_value.dtype == torch.float16
|
| 991 |
+
or weight_meta_value.dtype == torch.float16
|
| 992 |
+
):
|
| 993 |
+
if not mkldnn._is_mkldnn_fp16_supported():
|
| 994 |
+
return False
|
| 995 |
+
return True
|
| 996 |
+
|
| 997 |
+
_aten_conv_args = (
|
| 998 |
+
Arg(),
|
| 999 |
+
Arg(),
|
| 1000 |
+
Arg(),
|
| 1001 |
+
Arg(),
|
| 1002 |
+
Arg(),
|
| 1003 |
+
Arg(),
|
| 1004 |
+
KeywordArg("is_transposed"),
|
| 1005 |
+
Arg(),
|
| 1006 |
+
Arg(),
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
_aten_mkldnn_rnn_layer_args = (
|
| 1010 |
+
Arg(), # input
|
| 1011 |
+
Arg(), # weight0
|
| 1012 |
+
Arg(), # weight1
|
| 1013 |
+
Arg(), # weight2
|
| 1014 |
+
Arg(), # weight3
|
| 1015 |
+
Arg(), # hx_
|
| 1016 |
+
Arg(), # cx_
|
| 1017 |
+
KeywordArg("reverse"), # reverse
|
| 1018 |
+
Arg(), # batch_sizes
|
| 1019 |
+
Arg(), # mode
|
| 1020 |
+
Arg(), # hidden_size
|
| 1021 |
+
Arg(), # num_layers
|
| 1022 |
+
Arg(), # has_biases
|
| 1023 |
+
Arg(), # bidirectional
|
| 1024 |
+
Arg(), # batch_first
|
| 1025 |
+
Arg(), # train
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
def _register_weight_pack_pass():
|
| 1029 |
+
@register_freezing_graph_pattern(
|
| 1030 |
+
CallFunction(aten.convolution.default, *_aten_conv_args),
|
| 1031 |
+
extra_check=_is_packable_convolution,
|
| 1032 |
+
)
|
| 1033 |
+
def convolution(match, *args, **kwargs):
|
| 1034 |
+
is_transposed = kwargs.get("is_transposed")
|
| 1035 |
+
assert isinstance(is_transposed, bool)
|
| 1036 |
+
graph = match.graph
|
| 1037 |
+
conv_node = match.output_node()
|
| 1038 |
+
input_size = conv_node.args[0].meta.get("val").shape
|
| 1039 |
+
with graph.inserting_before(conv_node):
|
| 1040 |
+
constant_args = [args[4], args[3], args[5], args[-1]]
|
| 1041 |
+
packed_weight_op = mkldnn._reorder_convolution_weight
|
| 1042 |
+
packed_conv_op = mkldnn._convolution_pointwise.default
|
| 1043 |
+
if is_transposed:
|
| 1044 |
+
constant_args.insert(1, args[-2]) # output_padding
|
| 1045 |
+
packed_weight_op = mkldnn._reorder_convolution_transpose_weight
|
| 1046 |
+
packed_conv_op = mkldnn._convolution_transpose_pointwise.default
|
| 1047 |
+
if not has_free_symbols(input_size):
|
| 1048 |
+
packed_weight_inputs = (
|
| 1049 |
+
(args[1],) + tuple(constant_args) + (input_size,)
|
| 1050 |
+
)
|
| 1051 |
+
packed_weight_node = graph.create_node(
|
| 1052 |
+
"call_function", packed_weight_op, args=packed_weight_inputs
|
| 1053 |
+
)
|
| 1054 |
+
else:
|
| 1055 |
+
assert not is_transposed
|
| 1056 |
+
# For dynamic shape case, we need to pack weight in runtime.
|
| 1057 |
+
packed_weight_node = args[1]
|
| 1058 |
+
packed_conv_inputs = (
|
| 1059 |
+
(args[0], packed_weight_node, args[2])
|
| 1060 |
+
+ tuple(constant_args)
|
| 1061 |
+
+ ("none", [], "")
|
| 1062 |
+
)
|
| 1063 |
+
packed_conv_node = graph.create_node(
|
| 1064 |
+
"call_function", packed_conv_op, tuple(packed_conv_inputs)
|
| 1065 |
+
)
|
| 1066 |
+
conv_node.replace_all_uses_with(packed_conv_node)
|
| 1067 |
+
packed_conv_node.meta.update(conv_node.meta)
|
| 1068 |
+
graph.erase_node(conv_node)
|
| 1069 |
+
|
| 1070 |
+
@register_freezing_graph_pattern(
|
| 1071 |
+
CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
|
| 1072 |
+
extra_check=_is_packable_mkldnn_rnn_layer,
|
| 1073 |
+
)
|
| 1074 |
+
def mkldnn_rnn_layer(match, *args, **kwargs):
|
| 1075 |
+
def get_item(graph, node, index):
|
| 1076 |
+
return graph.call_function(operator.getitem, (node, index))
|
| 1077 |
+
|
| 1078 |
+
graph = match.graph
|
| 1079 |
+
lstm_node = match.output_node()
|
| 1080 |
+
input = args[0]
|
| 1081 |
+
weight0, weight1 = args[1:3]
|
| 1082 |
+
reverse = kwargs.get("reverse")
|
| 1083 |
+
packed_lstm_op = aten.mkldnn_rnn_layer.default
|
| 1084 |
+
hidden_size = args[9]
|
| 1085 |
+
has_biases = args[11]
|
| 1086 |
+
batch_first = args[13]
|
| 1087 |
+
with graph.inserting_before(lstm_node):
|
| 1088 |
+
packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
|
| 1089 |
+
packed_weight_inputs = (
|
| 1090 |
+
weight0,
|
| 1091 |
+
weight1,
|
| 1092 |
+
hidden_size,
|
| 1093 |
+
reverse,
|
| 1094 |
+
has_biases,
|
| 1095 |
+
batch_first,
|
| 1096 |
+
)
|
| 1097 |
+
packed_weight_node = graph.create_node(
|
| 1098 |
+
"call_function", packed_weight_op, packed_weight_inputs, {}, "name"
|
| 1099 |
+
)
|
| 1100 |
+
packed_weight_items = [
|
| 1101 |
+
get_item(graph, packed_weight_node, i) for i in range(2)
|
| 1102 |
+
]
|
| 1103 |
+
pack_lstm_inputs = (
|
| 1104 |
+
args[0],
|
| 1105 |
+
*packed_weight_items,
|
| 1106 |
+
args[3],
|
| 1107 |
+
args[4],
|
| 1108 |
+
args[5],
|
| 1109 |
+
args[6],
|
| 1110 |
+
reverse,
|
| 1111 |
+
*args[7:],
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
packed_lstm_node = graph.create_node(
|
| 1115 |
+
"call_function", packed_lstm_op, args=pack_lstm_inputs
|
| 1116 |
+
)
|
| 1117 |
+
lstm_node.replace_all_uses_with(packed_lstm_node)
|
| 1118 |
+
packed_lstm_node.meta.update(lstm_node.meta)
|
| 1119 |
+
graph.erase_node(lstm_node)
|
| 1120 |
+
|
| 1121 |
+
@register_freezing_graph_pattern(
|
| 1122 |
+
CallFunction(
|
| 1123 |
+
aten.addmm.default,
|
| 1124 |
+
Arg(),
|
| 1125 |
+
Arg(),
|
| 1126 |
+
Arg(),
|
| 1127 |
+
beta=KeywordArg("beta"),
|
| 1128 |
+
alpha=KeywordArg("alpha"),
|
| 1129 |
+
),
|
| 1130 |
+
extra_check=_is_packable_linear,
|
| 1131 |
+
)
|
| 1132 |
+
@register_freezing_graph_pattern(
|
| 1133 |
+
CallFunction(aten.mm.default, Arg(), Arg()),
|
| 1134 |
+
extra_check=_is_packable_linear,
|
| 1135 |
+
)
|
| 1136 |
+
def linear(match, *args, **kwargs):
|
| 1137 |
+
graph = match.graph
|
| 1138 |
+
linear_node = match.output_node()
|
| 1139 |
+
input = args[0] if linear_node.target == aten.mm.default else args[1]
|
| 1140 |
+
bias = (
|
| 1141 |
+
None
|
| 1142 |
+
if linear_node.target == aten.mm.default
|
| 1143 |
+
or (
|
| 1144 |
+
linear_node.target == aten.addmm.default
|
| 1145 |
+
and linear_node.kwargs.get("beta", 1.0) == 0.0
|
| 1146 |
+
)
|
| 1147 |
+
else args[0]
|
| 1148 |
+
)
|
| 1149 |
+
weight = args[1] if linear_node.target == aten.mm.default else args[2]
|
| 1150 |
+
with graph.inserting_before(linear_node):
|
| 1151 |
+
transpose_weight_node = graph.create_node(
|
| 1152 |
+
"call_function", aten.permute.default, (weight, (1, 0))
|
| 1153 |
+
)
|
| 1154 |
+
weight_dtype = weight.meta.get("val").dtype
|
| 1155 |
+
is_lp_weight = weight_dtype in (
|
| 1156 |
+
torch.bfloat16,
|
| 1157 |
+
torch.float16,
|
| 1158 |
+
)
|
| 1159 |
+
batch_size = input.meta.get("val").shape[0]
|
| 1160 |
+
if has_free_symbols(batch_size):
|
| 1161 |
+
assert (
|
| 1162 |
+
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
|
| 1163 |
+
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
| 1164 |
+
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
|
| 1165 |
+
packed_weight_inputs = (
|
| 1166 |
+
transpose_weight_node,
|
| 1167 |
+
batch_size.node.shape_env.size_hint(batch_size.node.expr)
|
| 1168 |
+
if has_free_symbols(batch_size)
|
| 1169 |
+
else batch_size,
|
| 1170 |
+
)
|
| 1171 |
+
# MKL packed matrix can't be copied to a different address because the internal implementation
|
| 1172 |
+
# depends on the alignment of internally-stored metadata.
|
| 1173 |
+
# In aot mode, we need to firstly save the packed weight, when loading it,
|
| 1174 |
+
# it will be in a different address which doesn't work.
|
| 1175 |
+
# Disable MKL prepack linear in AOT mode
|
| 1176 |
+
packed_weight_op = (
|
| 1177 |
+
mkldnn._reorder_linear_weight
|
| 1178 |
+
if (
|
| 1179 |
+
is_lp_weight
|
| 1180 |
+
or mkldnn._is_mkldnn_acl_supported()
|
| 1181 |
+
or V.aot_compilation is True
|
| 1182 |
+
)
|
| 1183 |
+
else torch.ops.mkl._mkl_reorder_linear_weight
|
| 1184 |
+
)
|
| 1185 |
+
packed_weight_node = graph.create_node(
|
| 1186 |
+
"call_function", packed_weight_op, args=packed_weight_inputs
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
|
| 1190 |
+
if (
|
| 1191 |
+
is_lp_weight
|
| 1192 |
+
or mkldnn._is_mkldnn_acl_supported()
|
| 1193 |
+
or V.aot_compilation is True
|
| 1194 |
+
):
|
| 1195 |
+
packed_linear_inputs += (bias, "none", [], "")
|
| 1196 |
+
packed_linear_op = mkldnn._linear_pointwise.default
|
| 1197 |
+
else:
|
| 1198 |
+
packed_linear_inputs += (transpose_weight_node, bias, batch_size)
|
| 1199 |
+
packed_linear_op = torch.ops.mkl._mkl_linear
|
| 1200 |
+
packed_linear_node = graph.create_node(
|
| 1201 |
+
"call_function", packed_linear_op, packed_linear_inputs
|
| 1202 |
+
)
|
| 1203 |
+
linear_node.replace_all_uses_with(packed_linear_node)
|
| 1204 |
+
packed_linear_node.meta.update(linear_node.meta)
|
| 1205 |
+
graph.erase_node(linear_node)
|
| 1206 |
+
|
| 1207 |
+
def _eliminate_duplicate_packed_nodes(gm):
|
| 1208 |
+
"""
|
| 1209 |
+
Combine packed weight nodes with the same inputs to reduce memory usage.
|
| 1210 |
+
for example:
|
| 1211 |
+
class Model(nn.Module):
|
| 1212 |
+
def __init__(self) -> None:
|
| 1213 |
+
super().__init__()
|
| 1214 |
+
self.linear = nn.Linear(32, 32, bias=True)
|
| 1215 |
+
|
| 1216 |
+
def forward(self, x):
|
| 1217 |
+
return self.linear(self.linear(x))
|
| 1218 |
+
|
| 1219 |
+
the above's packed weight nodes are duplicate if two linear calls have same input size.
|
| 1220 |
+
"""
|
| 1221 |
+
if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
|
| 1222 |
+
return gm
|
| 1223 |
+
|
| 1224 |
+
packed_weight_ops = [
|
| 1225 |
+
torch._C._nn.mkldnn_reorder_conv2d_weight,
|
| 1226 |
+
torch._C._nn.mkldnn_reorder_conv3d_weight,
|
| 1227 |
+
mkldnn._reorder_convolution_transpose_weight,
|
| 1228 |
+
mkldnn._reorder_linear_weight,
|
| 1229 |
+
mkldnn._reorder_mkldnn_rnn_layer_weight,
|
| 1230 |
+
]
|
| 1231 |
+
if torch._C.has_mkl:
|
| 1232 |
+
packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)
|
| 1233 |
+
|
| 1234 |
+
for node in gm.graph.nodes:
|
| 1235 |
+
if node.target in packed_weight_ops and len(node.args[0].users) > 1:
|
| 1236 |
+
for user_node in list(node.args[0].users.keys()):
|
| 1237 |
+
if (
|
| 1238 |
+
user_node.target == node.target
|
| 1239 |
+
and user_node != node
|
| 1240 |
+
and user_node.args == node.args
|
| 1241 |
+
):
|
| 1242 |
+
user_node.replace_all_uses_with(node)
|
| 1243 |
+
gm.graph.erase_node(user_node)
|
| 1244 |
+
|
| 1245 |
+
@functools.lru_cache(None)
|
| 1246 |
+
def _mkldnn_fusion_init():
|
| 1247 |
+
# TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
|
| 1248 |
+
# Otherwise even the matmul or innerproduct can not be accelerated with acl
|
| 1249 |
+
if (
|
| 1250 |
+
torch.backends.mkldnn.enabled
|
| 1251 |
+
and torch.backends.mkldnn.is_available()
|
| 1252 |
+
and not torch.ops.mkldnn._is_mkldnn_acl_supported()
|
| 1253 |
+
):
|
| 1254 |
+
_register_unary_fusion()
|
| 1255 |
+
_register_inplace_fusion()
|
| 1256 |
+
_register_binary_unary_fusion()
|
| 1257 |
+
_register_binary_fusion()
|
| 1258 |
+
_register_quantization_lowerings()
|
| 1259 |
+
_register_woq_lowerings()
|
| 1260 |
+
|
| 1261 |
+
@functools.lru_cache(None)
|
| 1262 |
+
def _mkldnn_weight_pack_init():
|
| 1263 |
+
if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
|
| 1264 |
+
_register_weight_pack_pass()
|
| 1265 |
+
_recover_linear()
|
| 1266 |
+
_register_quantization_weight_pack_pass()
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import gc
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import traceback
|
| 7 |
+
|
| 8 |
+
import numpy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
|
| 13 |
+
from .. import config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger: logging.Logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
MAIN_RANDOM_SEED = 1337
|
| 19 |
+
|
| 20 |
+
# Set the CUBLAS_WORKSPACE_CONFIG environment variable
|
| 21 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# If the two forward functions involve any non-deterministic operations,
|
| 25 |
+
# such as certain types of parallelism or asynchronous execution,
|
| 26 |
+
# this can also lead to different outputs.
|
| 27 |
+
def set_deterministic() -> None:
|
| 28 |
+
"""Make torch manual seed deterministic."""
|
| 29 |
+
|
| 30 |
+
torch.manual_seed(MAIN_RANDOM_SEED)
|
| 31 |
+
random.seed(MAIN_RANDOM_SEED)
|
| 32 |
+
numpy.random.seed(MAIN_RANDOM_SEED)
|
| 33 |
+
torch.use_deterministic_algorithms(True)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def clean_memory() -> None:
|
| 37 |
+
"""Clean memory to avoid OOM."""
|
| 38 |
+
gc.collect()
|
| 39 |
+
torch.cuda.empty_cache()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# We compare the numerical results before and after pre/post grad fx passes
|
| 43 |
+
# transformation to make sure the numerical results are the same.
|
| 44 |
+
# We compare the numerical results before and after pre/post grad fx passes
# transformation to make sure the numerical results are the same.
def compare_dict_tensors(dict_base, dict_control, precision):
    """Compare two name->tensor dicts with ``torch.allclose``.

    Args:
        dict_base: mapping captured before the fx-pass transformation.
        dict_control: mapping captured after the transformation.
        precision: used as both rtol and atol for the comparison.

    Returns:
        True when every comparable entry matches; False on any key or
        value mismatch. Entries that are ``None`` on either side are
        skipped (not every param has a valid .grad field).
    """
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            # BUGFIX: previously fell through to dict_control[key] and raised
            # KeyError; record the mismatch and move on instead.
            is_allclose = False
            continue
        # Some parameters have `None`, and not every param has a valid .grad field, we skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def compare_tuple_tensors(tuple_base, tuple_control, precision):
    """Elementwise ``torch.allclose`` comparison of two output tuples.

    Pairs where either side is ``None`` are skipped. Returns False on a
    length mismatch or any non-close pair; every mismatch is logged at
    debug level rather than raising.
    """
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s, after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    all_match = True
    for base_out, control_out in zip(tuple_base, tuple_control):
        # Some entries can be `None`; skip those pairs.
        if base_out is None or control_out is None:
            continue
        close = torch.allclose(
            base_out,
            control_out,
            rtol=precision,
            atol=precision,
            equal_nan=True,
        )
        if not close:
            logger.debug(
                "forward output before pre/post grad fx passes %s", base_out
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", control_out
            )
            all_match = False
    return all_match
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def compare_parameters(model_base, model_control, precision):
    """Compare the named parameters of both models at the given tolerance."""
    params_base = dict(model_base.named_parameters())
    params_control = dict(model_control.named_parameters())
    return compare_dict_tensors(params_base, params_control, precision)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def compare_forward_output(pred_base, pred_control, precision):
    """Compare forward outputs of the two models via compare_tuple_tensors."""
    return compare_tuple_tensors(pred_base, pred_control, precision)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def compare_gradients(model_base, model_control, precision):
    """Compare the .grad of every named parameter across the two models."""
    grads_base = {name: p.grad for name, p in model_base.named_parameters()}
    grads_control = {name: p.grad for name, p in model_control.named_parameters()}
    return compare_dict_tensors(grads_base, grads_control, precision)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    """Run both models side by side and log numeric comparisons.

    Each iteration reseeds the RNGs, runs both forward passes on the same
    input, then compares parameters, forward outputs and (best-effort)
    gradients, logging the result of each comparison instead of raising.

    Args:
        model_base: model before the fx-pass transformation.
        model_control: model after the fx-pass transformation.
        model_input: positional inputs, unpacked into each model call.
        num_iterations: number of comparison rounds.
        precision: rtol/atol used by the underlying allclose checks.
    """
    clean_memory()
    for i in range(num_iterations):
        logger.info("start %s iteration", i)
        # Reseed before *each* forward so both models see identical RNG state.
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result : %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result : %s", res)
        # tensor may not have a grad_fn
        try:
            _ = pred_base[0].sum().backward(retain_graph=True)
            _ = pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result : %s", res)
        except Exception:
            logger.exception("Exception when comparing gradients")
            traceback.print_exc()

        # Optionally run one SGD step on both models and re-compare, to
        # surface parameter-name mismatches that only optimizers hit.
        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result : %s",
                    res,
                )
            except Exception as e:
                logger.exception(
                    "Exception when optimizer is added to check parameter names"
                )
                traceback.print_exc()
        else:
            logger.warning(
                "no parameter with optimizer to compare with length %s before transformation"
                " and the length %s after transformation",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    """Run the before/after numeric comparison, never letting it raise.

    NOTE: the graphmodule needs to be topo-sorted before running, otherwise
    execution may hit a use-before-def. Any failure is logged and swallowed
    so the numeric check cannot block the actual model run.
    """
    try:
        # anomaly mode yields better backtraces if a backward pass fails
        with torch.autograd.set_detect_anomaly(True):
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as exc:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", exc
        )
        traceback.print_exc()
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import operator
|
| 5 |
+
import typing
|
| 6 |
+
from typing import Callable, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._inductor.runtime.runtime_utils
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from torch._dynamo.utils import counters
|
| 12 |
+
from torch._inductor import utils
|
| 13 |
+
from torch._inductor.autoheuristic.autoheuristic import (
|
| 14 |
+
AHContext,
|
| 15 |
+
AutoHeuristic,
|
| 16 |
+
LocalFeedback,
|
| 17 |
+
)
|
| 18 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 19 |
+
context_add_strides,
|
| 20 |
+
context_add_using_tf32,
|
| 21 |
+
pad_mm_operations,
|
| 22 |
+
pad_mm_precondition,
|
| 23 |
+
)
|
| 24 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 25 |
+
from torch.utils._mode_utils import no_dispatch
|
| 26 |
+
|
| 27 |
+
from ...utils._triton import has_triton
|
| 28 |
+
from ..pattern_matcher import (
|
| 29 |
+
fwd_only,
|
| 30 |
+
gen_register_replacement,
|
| 31 |
+
joint_fwd_bwd,
|
| 32 |
+
Match,
|
| 33 |
+
ReplaceFn,
|
| 34 |
+
SearchFn,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Shorthand for the ATen op namespace used throughout this file.
aten = torch.ops.aten


# This flag is only used for testing purpose.
# Changing it to True will ignore comparing do_bench times
# between original pattern and padded one.
_skip_do_bench_times = False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    """Return the fake ("meta val") tensors bound to the given kwargs of a match."""
    return [match.kwargs[name].meta["val"] for name in kwarg_names]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def unwrap_fake_args(*arg_names):
    """Decorator factory: call *func* with the fake tensors named by arg_names.

    The wrapped function receives the match's fake tensors positionally
    instead of the raw Match object.
    """

    def decorator(func):
        def wrapper(match):
            # resolve each named kwarg to its fake tensor before delegating
            return func(*fetch_fake_tensors(match, arg_names))

        return wrapper

    return decorator
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_alignment_size(x: Tensor) -> int:
    """Alignment (in elements) to pad tensor *x* to, based solely on its dtype."""
    return get_alignment_size_dtype(x.dtype)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_alignment_size_dtype(dtype: torch.dtype) -> int:
    """Element alignment for padding by dtype; 0 means "do not pad".

    torch.half aliases float16 and torch.float aliases float32; they are
    listed explicitly to mirror the caller-facing contract.
    """
    if dtype in (torch.float16, torch.half, torch.bfloat16):
        return 8
    if dtype in (torch.float32, torch.float):
        return 4
    return 0
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def check_device(a: Tensor, b: Tensor) -> bool:
    """True when both operands live on a CUDA device."""
    return all(t.is_cuda for t in (a, b))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def check_dtype(a: Tensor, b: Tensor) -> bool:
    """True when both operands are floating point tensors."""
    return all(t.is_floating_point() for t in (a, b))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    """Cheap eligibility gate for shape padding of a mm/addmm/bmm.

    Symbolic shapes or strides are fine as long as they carry hints —
    later code only pads the non-symbolic dimensions. Returns False when
    shape padding is disabled, the operands are not CUDA floats, or any
    operand has an unusable (hint-less or fully symbolic) layout.
    """

    def _layout_is_usable(t: Optional[Tensor]) -> bool:
        # A missing optional operand never blocks padding.
        if t is None:
            return True
        n_symbolic = 0
        for dim in t.size():
            if isinstance(dim, int):
                continue
            if not utils.is_symbolic(dim):
                return False
            if not dim.node.has_hint():
                return False
            n_symbolic += 1
        # reject fully-symbolic shapes: nothing concrete left to pad
        if n_symbolic == len(t.size()):
            return False
        for s in t.stride():
            if isinstance(s, int):
                continue
            if not (utils.is_symbolic(s) and s.node.has_hint()):
                return False
        return True

    if not torch._inductor.config.shape_padding:
        return False
    if not check_device(mat1, mat2):
        return False
    if not check_dtype(mat1, mat2):
        return False
    return all(_layout_is_usable(t) for t in (mat1, mat2, input))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    """Elements needed to round *x* up to a multiple of *alignment_size*.

    Returns 0 (no padding) for symbolic sizes, unpaddable dtypes
    (alignment 0), already-aligned sizes, and length-1 dims (which can
    be squeezed away instead of padded).
    """
    if isinstance(x, torch.SymInt):
        return 0
    if alignment_size == 0 or x % alignment_size == 0:
        return 0
    if x == 1:
        return 0
    return alignment_size - (x % alignment_size)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    """Append *padded_length* zeros to *x* along *dim* (no-op when 0)."""
    if padded_length == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = padded_length
    zeros = x.new_zeros(pad_shape)
    return torch.cat([x, zeros], dim=dim)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    """Pattern-matcher search template: a plain aten.addmm call.

    Traced by the pattern matcher to find addmm nodes eligible for shape
    padding; the body is intentionally a single op call.
    """
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def should_pad_addmm(match: Match) -> bool:
    """Decide whether this matched addmm is worth padding.

    Applies the cheap common gate first, then the (potentially expensive)
    benchmark-based check.
    """
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    if not should_pad_common(mat1, mat2, input):
        return False
    return should_pad_bench(match, mat1, mat2, torch.ops.aten.addmm, input=input)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
):
    """Run addmm on zero-padded operands, then slice padding off the result.

    The *_pre_padded flags indicate operands the caller already padded.
    """
    # Pad the matmul operands unless the caller already did.
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )

    # `input` participates via broadcasting, so a size-1 dim must stay size 1.
    if input is not None:
        if n_padded_length != 0:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    padded = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)

    # Slice the padding back off the result (k padding vanishes in the matmul).
    if m_padded_length != 0:
        padded = padded[:-m_padded_length, :]
    if n_padded_length != 0:
        padded = padded[:, :-n_padded_length]
    return padded
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    """Replacement for a matched addmm: compute pad lengths and delegate."""
    align1 = get_alignment_size(mat1)
    align2 = get_alignment_size(mat2)
    m_pad = get_padded_length(mat1.shape[0], align1)
    k_pad = get_padded_length(mat1.shape[1], align1)
    n_pad = get_padded_length(mat2.shape[1], align2)
    return pad_addmm(
        input,
        mat1,
        mat2,
        m_pad,
        k_pad,
        n_pad,
        beta,
        alpha,
    )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    """Heuristic: is an MxK @ KxN matmul compute (not bandwidth) bound?

    Compares arithmetic intensity (flops per element moved) against an
    estimated machine balance; degenerate (zero-sized) problems are never
    compute bound.
    """
    total_io = M * K + N * K + M * N
    if total_io == 0:
        return False
    arithmetic_intensity = (M * N * K) / total_io

    # Large-K bf16 matmuls have shown big perf hits on pre-Hopper devices
    # even in bandwidth-bound regimes (doesn't repro on H100s), so treat
    # them as compute bound unconditionally.
    if (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and torch.cuda.get_device_capability() < (9, 0)
    ):
        return True

    # get_device_tflops can fail (e.g. on AMD); be conservative then.
    try:
        machine_balance = (
            1000 * utils.get_device_tflops(dtype)
        ) / utils.get_gpu_dram_gbps()
    except Exception:
        return True

    # dram_gbps might be underestimating bandwidth because of cache:
    # too-low a balance misses speedups, too-high wastes compile time.
    # TODO - finetune coefficient here. As a reference point, Triton mm
    # model assumes 80% of reads are in cache and cache is 4x faster.
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@functools.lru_cache(None)
def get_pad_cache():
    """Process-wide LocalCache for pad/no-pad decisions (created once)."""
    return torch._inductor.codecache.LocalCache()
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def get_cached_should_pad(key: str) -> Optional[bool]:
    """Look up a cached pad/no-pad decision.

    Returns None on a cache miss — LocalCache.lookup returns None for
    absent keys and the caller distinguishes miss from False via
    ``is not None``, so the previous ``-> bool`` annotation was wrong.
    """
    return get_pad_cache().lookup(key)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def set_cached_should_pad(key: str, value: bool):
    """Persist a pad/no-pad decision under *key* in the local pad cache."""
    return get_pad_cache().set_value(key, value=value)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_cached_base_mm_benchmark_time(key: str) -> Optional[float]:
    """Look up the cached benchmark time of the unpadded mm.

    Returns None on a cache miss — the caller checks ``ori_time is None``
    before re-benchmarking, so the previous ``-> float`` annotation was
    wrong.
    """
    return get_pad_cache().lookup(key)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def set_cached_base_mm_benchmark_time(key: str, value: float):
    """Persist the unpadded-mm benchmark time under *key*."""
    return get_pad_cache().set_value(key, value=value)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def should_pad_bench_key(
    match,
    mat1: Tensor,
    mat2: Tensor,
    op,
    input: Optional[Tensor] = None,
    is_base_time_key=False,
) -> str:
    """Build the cache key for a pad-benchmark decision.

    Keys on operand layouts (shape/stride/dtype), the op, tf32 mode for
    fp32 inputs, and — unless this is a base-time key — whether each
    operand's padding cost can be excluded.
    """

    def _sig(t):
        return (t.shape, t.stride(), t.dtype)

    # tf32 only affects fp32 matmuls, so only key on it in that case.
    if mat1.dtype == torch.float32:
        tf32_key = torch.backends.cuda.matmul.allow_tf32
    else:
        tf32_key = None

    def _pad_tag(name):
        # base-time keys must not depend on the pad-exclusion analysis
        if is_base_time_key:
            return None
        return f"exclude_pad:{should_exclude_padding_time(match, name)}"

    parts = (
        _sig(mat1),
        _sig(mat2),
        _pad_tag("mat1"),
        _pad_tag("mat2"),
        op,
        input if input is None else _sig(input),
        tf32_key,
    )

    key = str(parts)
    if is_base_time_key:
        key = f"base mm time: {key}"
    return key
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def get_non_view_def(node):
    """Walk backwards past getitem and view ops to the defining fx node.

    Used to find the node that actually materializes a tensor, so padding
    cost analysis can reason about the real producer.
    """
    # getitem nodes are call_function nodes whose target is operator.getitem.
    # fx Node.op is always one of the opcode strings ("placeholder",
    # "call_function", ...), so the previous comparison
    # `node.op == operator.getitem` could never be true and this unwrapping
    # branch was dead; compare the target instead.
    if node.op == "call_function" and node.target is operator.getitem:
        return get_non_view_def(node.args[0])

    if (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and utils.is_view(node.target)
    ):
        return get_non_view_def(node.all_input_nodes[0])

    return node
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def should_exclude_padding_time(match, arg_name):
    """Can the cost of padding this operand be excluded from benchmarking?

    True when the operand's padding is expected to be memory-planned away,
    so it should not count against the padded variant's time.
    """
    node_def = get_non_view_def(match.kwargs[arg_name])

    # Constant padding converts tensors to contiguous, so even a plannable
    # non-contiguous input still pays a layout transform.
    # TODO - way to pad and preserve layout?
    if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
        return False

    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
    # Ops whose outputs are fixed dense: we could only fully plan these if
    # we padded the first dimension only; any other dim still needs a copy.
    fixed_dense_output_ops = [
        aten.mm.default,
        aten.convolution.default,
        aten.convolution_backward.default,
        aten.bmm.default,
        aten.addmm.default,
        aten._scaled_dot_product_flash_attention.default,
        aten._scaled_dot_product_efficient_attention.default,
    ]
    if node_def.target in fixed_dense_output_ops:
        return False

    # Wide concats exceed the pointwise-cat fusion limit, so no planning.
    if node_def.target == aten.cat.default and (
        len(node_def.all_input_nodes)
        > torch._inductor.config.max_pointwise_cat_inputs
    ):
        return False

    # Optimistically assume anything that is not a graph input can be
    # memory-planned away.
    return node_def.op != "placeholder"
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def should_pad(key: str, ori_time, pad_time) -> bool:
    """Compare benchmark times, cache the decision, and return it.

    Shape padding introduces additional memory ops; based on
    microbenchmarks, requiring ~10% headroom (1.1x) is a reasonable
    tradeoff between padding speedups and that overhead.
    TODO: build a learned model — it would beat this heuristic.
    """
    options = torch._inductor.config.post_grad_fusion_options
    if "shape_padding_multiplier" in options:
        multiplier = options["shape_padding_multiplier"].get("value", 1.1)
        counters["inductor"]["shape_padding_multiplier"] += 1
    else:
        multiplier = 1.1
    decision = _skip_do_bench_times or ori_time > pad_time * multiplier
    set_cached_should_pad(key, decision)
    return decision
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def should_pad_bench(
    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    """Decide via benchmarking whether padding this mm/addmm/bmm pays off.

    Computes pad lengths from dtype alignment, applies several cheap
    early-outs (nothing to pad, zero-sized dims, forced padding, no
    triton, not compute bound), then benchmarks the original op against
    the padded variant on materialized tensors and caches the verdict.

    Returns True when the padded variant should be used.
    """
    do_bench = functools.partial(
        torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
        warmup=5,
    )
    m_padded_length = 0
    n_padded_length = 0
    # NOTE(review): batchsize is assigned for bmm but not used below in this
    # view of the function — possibly leftover from an earlier revision.
    batchsize = 1
    # no_dispatch so the benchmark tensors below bypass fake-tensor modes
    with no_dispatch():
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m = mat1.shape[0]
            k = mat1.shape[1]
            n = mat2.shape[1]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
        elif op is torch.ops.aten.bmm:
            batchsize = mat1.shape[0]
            m = mat1.shape[1]
            k = mat1.shape[2]
            n = mat2.shape[2]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        else:
            return False

        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        def realize_symbols(ds):
            # replace SymInts by their hints so we can build real tensors
            return [d if isinstance(d, int) else d.node.hint for d in ds]

        if any(
            dim == 0
            for dim in itertools.chain(
                realize_symbols(mat1.shape), realize_symbols(mat2.shape)
            )
        ):
            return False

        if torch._inductor.config.force_shape_pad:
            return True

        if not has_triton():
            return False

        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially false
        # since it does file io
        key = should_pad_bench_key(match, mat1, mat2, op, input)

        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_tensor(t):
            # turn a fake tensor into a real random tensor with the same
            # (hinted) size/stride; plain tensors just get randn_like
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            else:
                return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)

        # since we key on whether or not the inputs can be memory planned, set cache for the
        # original time which is unaffected by whether or not the input can be planned
        ori_time_key = should_pad_bench_key(
            match, mat1, mat2, op, input, is_base_time_key=True
        )
        ori_time = get_cached_base_mm_benchmark_time(ori_time_key)
        if ori_time is None and op is torch.ops.aten.addmm and input is not None:
            # realize bias for addmm
            input = realize_tensor(input)

        mat1_pad = mat1
        mat2_pad = mat2

        is_bmm = op is torch.ops.aten.bmm

        # When an operand's padding can be memory-planned away, pad it once
        # up front and only benchmark re-zeroing the pad region.
        mat1_pre_padded = should_exclude_padding_time(match, "mat1")
        fns = []
        if mat1_pre_padded and (m_padded_length or k_padded_length):
            mat1_pad = pad_mat1(
                mat1_pad,
                m_padded_length=m_padded_length,
                k_padded_length=k_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0)
                else:
                    mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0)

            fns.append(write_pad)

        mat2_pre_padded = should_exclude_padding_time(match, "mat2")
        if mat2_pre_padded and (k_padded_length or n_padded_length):
            mat2_pad = pad_mat2(
                mat2_pad,
                k_padded_length=k_padded_length,
                n_padded_length=n_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0)
                else:
                    mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0)

            fns.append(write_pad)

        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            fns.append(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        elif op is torch.ops.aten.mm:
            fns.append(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        else:
            fns.append(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )

        def orig_bench_fn():
            if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
                op(mat1, mat2)
            else:
                op(input, mat1, mat2)

        def pad_bench_fn():
            # runs the pre-pad re-zeroing (if any) plus the padded op
            for fn in fns:
                fn()

        if (
            torch._inductor.config.run_autoheuristic("pad_mm")
            and op is torch.ops.aten.mm
        ):
            ah_should_pad = run_autoheuristic(
                mat1,
                mat2,
                orig_bench_fn,
                pad_bench_fn,
                m_padded_length,
                k_padded_length,
                n_padded_length,
                do_bench,
                mat1_pre_padded,
                mat2_pre_padded,
                ori_time,
                ori_time_key,
                key,
            )
            if ah_should_pad is not None:
                return ah_should_pad

        if ori_time is None:
            ori_time = do_bench(orig_bench_fn)
            set_cached_base_mm_benchmark_time(ori_time_key, ori_time)

        pad_time = do_bench(pad_bench_fn)
        return should_pad(key, ori_time, pad_time)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def get_context(
    mat1: Tensor,
    mat2: Tensor,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
):
    """Build the AHContext describing this mm padding decision.

    Records the problem sizes (m, k, n), the strides and alignment of both
    operands, the proposed pad amounts, dtypes, and whether either operand
    arrives already padded, so AutoHeuristic can learn/predict whether
    padding pays off.
    """
    ctx = AHContext()

    # Problem dimensions: mat1 is (m, k), mat2 is (k, n).
    for feat_name, feat_val in (
        ("m", mat1.shape[0]),
        ("k", mat1.shape[1]),
        ("n", mat2.shape[1]),
    ):
        ctx.add_feature(feat_name, feat_val)

    context_add_strides(ctx, "mat1", mat1.stride())
    context_add_strides(ctx, "mat2", mat2.stride())

    # Proposed pad amounts and the alignment each operand currently has.
    for feat_name, feat_val in (
        ("m_padded_length", m_padded_length),
        ("k_padded_length", k_padded_length),
        ("n_padded_length", n_padded_length),
        ("mat1_align_size", get_alignment_size(mat1)),
        ("mat2_align_size", get_alignment_size(mat2)),
    ):
        ctx.add_feature(feat_name, feat_val)

    # Categorical features: dtypes and whether inputs are pre-padded.
    for feat_name, feat_val in (
        ("mat1_dtype", mat1.dtype),
        ("mat2_dtype", mat2.dtype),
        ("prepadded_mat1", mat1_pre_padded),
        ("prepadded_mat2", mat2_pre_padded),
    ):
        ctx.add_feature(feat_name, feat_val, is_categorical=True)

    context_add_using_tf32(ctx, mat1.dtype)
    return ctx
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def run_autoheuristic(
    mat1: Tensor,
    mat2: Tensor,
    orig_bench_fn: Callable[[], None],
    pad_bench_fn: Callable[[], None],
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    do_bench,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    ori_time,
    ori_time_key: str,
    key: str,
) -> Optional[bool]:
    """Ask AutoHeuristic whether padding this mm is worthwhile.

    Returns True/False when the heuristic commits to a decision, or None
    when the answer is "autotune", i.e. the caller should fall back to
    benchmarking both variants itself.
    """

    def feedback_fn(choice: str):
        # Benchmark whichever variant AutoHeuristic asks about; unknown
        # choices yield no feedback.
        if choice == orig_choice:
            return do_bench(orig_bench_fn)
        elif choice == pad_choice:
            return do_bench(pad_bench_fn)
        return None

    def fallback() -> str:
        # When no learned heuristic applies, defer the decision.
        return "autotune"

    orig_choice = "orig"
    pad_choice = "pad"
    choices = [orig_choice, pad_choice]
    feedback = LocalFeedback(feedback_fn)
    context = get_context(
        mat1,
        mat2,
        mat1_pre_padded,
        mat2_pre_padded,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )
    name = "pad_mm"
    autoheuristic = AutoHeuristic(
        fallback=fallback,
        choices=choices,
        feedback=feedback,
        context=context,
        name=name,
        augment_context=pad_mm_operations(),
        precondition=pad_mm_precondition,
    )
    choice = autoheuristic.get_choice()
    # Map the heuristic's string choice to the tri-state return value.
    choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
    ah_should_pad = choice2should_pad.get(choice, None)

    if torch._inductor.config.collect_autoheuristic(name):
        # In data-collection mode, prefer the timings AutoHeuristic just
        # measured over the mapped choice.
        ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
        ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)

        # if precondition is not satisfied, autoheuristic does not collect data
        if ah_ori_time is not None and ah_pad_time is not None:
            if ori_time is None:
                # Reuse the measured baseline so later calls skip re-benchmarking.
                set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time)
            return should_pad(key, ah_ori_time, ah_pad_time)
    if ah_should_pad is not None:
        # Cache the committed decision for this problem key.
        set_cached_should_pad(key, ah_should_pad)
    return ah_should_pad
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Search pattern for the padding pass: a plain aten.mm.

    Traced by the pattern matcher; the op sequence must stay exactly this.
    """
    return aten.mm(mat1, mat2)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def should_pad_mm(match: Match) -> bool:
    """Extra-check for the mm padding pattern.

    True only when both fake operands satisfy the common padding
    preconditions and benchmarking says padding is profitable.
    """
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    if not should_pad_common(mat1, mat2):
        return False
    return should_pad_bench(match, mat1, mat2, torch.ops.aten.mm)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
    """Zero-pad mat1 by m_padded_length rows and k_padded_length columns.

    For bmm inputs the leading batch dimension shifts the row/col dims by
    one and is never padded. Returns mat1 unchanged when no padding is
    requested.
    """
    pad_m = m_padded_length != 0
    pad_k = k_padded_length != 0
    if not (pad_m or pad_k):
        return mat1
    if pad_m and pad_k:
        # constant_pad_nd takes (left, right) pairs starting from the LAST dim.
        pad_arg = [0, k_padded_length, 0, m_padded_length]
        if is_bmm:
            pad_arg += [0, 0]  # batch dim: no padding
        return aten.constant_pad_nd(mat1, pad_arg)
    # Exactly one dimension needs padding; batch dim offsets indices by one.
    dim_offset = 1 if is_bmm else 0
    if pad_m:
        return pad_dim(mat1, m_padded_length, 0 + dim_offset)
    return pad_dim(mat1, k_padded_length, 1 + dim_offset)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
    """Zero-pad mat2 by k_padded_length rows and n_padded_length columns.

    Mirror of pad_mat1 for the right-hand operand; for bmm the batch
    dimension shifts the indices by one and is never padded.
    """
    pad_k = k_padded_length != 0
    pad_n = n_padded_length != 0
    if not (pad_k or pad_n):
        return mat2
    if pad_k and pad_n:
        # constant_pad_nd takes (left, right) pairs starting from the LAST dim.
        pad_arg = [0, n_padded_length, 0, k_padded_length]
        if is_bmm:
            pad_arg += [0, 0]  # batch dim: no padding
        return aten.constant_pad_nd(mat2, pad_arg)
    # Exactly one dimension needs padding; batch dim offsets indices by one.
    dim_offset = 1 if is_bmm else 0
    if pad_k:
        return pad_dim(mat2, k_padded_length, 0 + dim_offset)
    return pad_dim(mat2, n_padded_length, 1 + dim_offset)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def pad_mm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    """Compute mat1 @ mat2 on padded operands, then slice back to (m, n).

    Operands flagged *_pre_padded are used as-is; the shared inner (k)
    padding cancels out in the product, so only the m and n padding must
    be sliced off the result.
    """
    lhs = (
        mat1
        if mat1_pre_padded
        else pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    )
    rhs = (
        mat2
        if mat2_pre_padded
        else pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )
    )
    out = aten.mm(lhs, rhs)
    # Drop the padded tail rows/columns to restore the original output shape.
    if m_padded_length:
        out = out[:-m_padded_length, :]
    if n_padded_length:
        out = out[:, :-n_padded_length]
    return out
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for mm_pattern: pad both operands to alignment, mm, slice.

    Pad amounts are derived from each operand's own alignment requirement.
    """
    align1 = get_alignment_size(mat1)
    align2 = get_alignment_size(mat2)
    pad_m = get_padded_length(mat1.shape[0], align1)
    pad_k = get_padded_length(mat1.shape[1], align1)
    pad_n = get_padded_length(mat2.shape[1], align2)
    return pad_mm(mat1, mat2, pad_m, pad_k, pad_n)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Search pattern for the padding pass: a plain aten.bmm.

    Traced by the pattern matcher; the op sequence must stay exactly this.
    """
    return aten.bmm(mat1, mat2)
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def should_pad_bmm(match: Match) -> bool:
    """Extra-check for the bmm padding pattern.

    True only when both fake operands satisfy the common padding
    preconditions and benchmarking says padding is profitable.
    """
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    if not should_pad_common(mat1, mat2):
        return False
    return should_pad_bench(match, mat1, mat2, torch.ops.aten.bmm)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
def pad_bmm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    """Compute a batched mat1 @ mat2 on padded operands, slice back to (b, m, n).

    Same contract as pad_mm but with a leading batch dimension: the k
    padding cancels in the product, so only m and n padding is sliced off.
    """
    lhs = (
        mat1
        if mat1_pre_padded
        else pad_mat1(
            mat1,
            m_padded_length=m_padded_length,
            k_padded_length=k_padded_length,
            is_bmm=True,
        )
    )
    rhs = (
        mat2
        if mat2_pre_padded
        else pad_mat2(
            mat2,
            k_padded_length=k_padded_length,
            n_padded_length=n_padded_length,
            is_bmm=True,
        )
    )
    out = aten.bmm(lhs, rhs)
    # Drop the padded tail rows/columns; the batch dimension is untouched.
    if m_padded_length:
        out = out[:, :-m_padded_length, :]
    if n_padded_length:
        out = out[:, :, :-n_padded_length]
    return out
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for bmm_pattern: pad both operands to alignment, bmm, slice.

    mat1 is (b, m, k), mat2 is (b, k, n); pad amounts come from each
    operand's own alignment requirement.
    """
    align1 = get_alignment_size(mat1)
    align2 = get_alignment_size(mat2)
    pad_m = get_padded_length(mat1.shape[1], align1)
    pad_k = get_padded_length(mat1.shape[2], align1)
    pad_n = get_padded_length(mat2.shape[2], align2)
    return pad_bmm(mat1, mat2, pad_m, pad_k, pad_n)
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
@functools.lru_cache(None)
def _pad_mm_init():
    """Register the mm/bmm/addmm padding patterns with the joint-graph pass.

    Runs once (lru_cache); for each op it registers both a training
    (joint fwd+bwd) and an inference (fwd-only) replacement guarded by the
    corresponding should_pad_* extra check.
    """
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds

    # Factories for dummy 2-d / 3-d / 1-d tracing inputs.
    dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
    dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)

    dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
    dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)

    dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)

    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

    for pattern, replacement, args, workaround, extra_check in [
        (
            typing.cast(SearchFn, mm_pattern),
            typing.cast(ReplaceFn, mm_replace),
            [dim2a(), dim2b()],
            {},
            should_pad_mm,
        ),
        (
            typing.cast(SearchFn, bmm_pattern),
            typing.cast(ReplaceFn, bmm_replace),
            [dim3a(), dim3b()],
            {},
            should_pad_bmm,
        ),
        (
            typing.cast(SearchFn, addmm_pattern),
            typing.cast(ReplaceFn, addmm_replace),
            [dim1a(), dim2a(), dim2b()],
            rep,
            should_pad_addmm,
        ),
    ]:
        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
        name = pattern.__name__

        gen_register_replacement(
            f"{name}_training",
            pattern,
            replacement,
            args,
            joint_fwd_bwd,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

        gen_register_replacement(
            f"{name}_inference",
            pattern,
            replacement,
            args,
            fwd_only,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py
ADDED
|
@@ -0,0 +1,1318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import functools
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import operator
|
| 7 |
+
from collections import Counter, defaultdict
|
| 8 |
+
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch._inductor as inductor
|
| 12 |
+
import torch.utils._pytree as pytree
|
| 13 |
+
from torch import fx
|
| 14 |
+
from torch._decomp import register_decomposition
|
| 15 |
+
from torch._dynamo.utils import counters, optimus_scuba_log
|
| 16 |
+
from torch._inductor import comms
|
| 17 |
+
from torch._inductor.virtualized import ops
|
| 18 |
+
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
|
| 19 |
+
from torch._utils_internal import upload_graph
|
| 20 |
+
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
| 21 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 22 |
+
|
| 23 |
+
from .. import config, ir, pattern_matcher
|
| 24 |
+
from ..codegen.common import BackendFeature, has_backend_feature
|
| 25 |
+
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
|
| 26 |
+
from ..lowering import lowerings as L
|
| 27 |
+
from ..pattern_matcher import (
|
| 28 |
+
_return_true,
|
| 29 |
+
Arg,
|
| 30 |
+
CallFunction,
|
| 31 |
+
CallFunctionVarArgs,
|
| 32 |
+
filter_nodes,
|
| 33 |
+
get_arg_value,
|
| 34 |
+
get_mutation_region_id,
|
| 35 |
+
Ignored,
|
| 36 |
+
init_once_fakemode,
|
| 37 |
+
KeywordArg,
|
| 38 |
+
ListOf,
|
| 39 |
+
Match,
|
| 40 |
+
MULTIPLE,
|
| 41 |
+
PatternMatcherPass,
|
| 42 |
+
register_graph_pattern,
|
| 43 |
+
stable_topological_sort,
|
| 44 |
+
)
|
| 45 |
+
from ..utils import decode_device, get_gpu_type, is_pointwise_use
|
| 46 |
+
from ..virtualized import V
|
| 47 |
+
from .b2b_gemm import B2B_GEMM_PASS
|
| 48 |
+
from .ddp_fusion import fuse_ddp_communication
|
| 49 |
+
from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
|
| 50 |
+
from .micro_pipeline_tp import micro_pipeline_tp_pass
|
| 51 |
+
from .pre_grad import is_same_dict, save_inductor_dict
|
| 52 |
+
from .reinplace import reinplace_inplaceable_ops
|
| 53 |
+
from .split_cat import POST_GRAD_PATTERNS
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if TYPE_CHECKING:
|
| 57 |
+
from sympy import Expr
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
log = logging.getLogger(__name__)
# Shorthand op namespaces used throughout the patterns below.
aten = torch.ops.aten
prims = torch.ops.prims

# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
    """
    Passes that run on after grad. This is called once on the forwards
    graph and once on the backwards graph.

    The IR here has been normalized and functionalized.

    NOTE(review): the statement order below is load-bearing — mutation-
    introducing passes must stay last (see ./fx_passes/README.md).
    """
    if config.dce:
        # has some issues with mutation in inference mode
        gm.graph.eliminate_dead_code()

    if is_inference and config.reorder_for_locality:
        reorder_for_locality(gm.graph)

    # Tracks fake-tensor metadata so it can be incrementally refreshed after
    # the graph rewrites below.
    fake_tensor_updater = FakeTensorUpdater(gm.graph)

    if config.post_grad_custom_pre_pass is not None:
        with GraphTransformObserver(
            gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
        ):
            config.post_grad_custom_pre_pass(gm.graph)

    if config.pattern_matcher:
        lazy_init()
        optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
        group_batch_fusion_passes(gm.graph, pre_grad=False)
        remove_noop_ops(gm.graph)
        # Apply the three priority buckets in order (see pass_patterns above).
        for patterns in pass_patterns:
            patterns.apply(gm.graph)  # type: ignore[arg-type]
        for pass_name in config.post_grad_fusion_options:
            # skip all patterns for group batch fusions
            if pass_name in POST_GRAD_FUSIONS:
                continue
            pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
            inductor_before_change = save_inductor_dict(
                [pattern_matcher_pass.pass_name]
            )
            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
            # Log the graph only when the pass actually changed something.
            if not is_same_dict(counters["inductor"], inductor_before_change):
                optimus_scuba_log[
                    f"{pattern_matcher_pass.pass_name}_post_grad"
                ] = upload_graph(gm.graph)
        if config.b2b_gemm_pass:
            B2B_GEMM_PASS.apply(gm.graph)  # type: ignore[arg-type]

    if config._micro_pipeline_tp:
        micro_pipeline_tp_pass(gm.graph)

    if config._fuse_ddp_communication:
        fuse_ddp_communication(
            gm.graph,
            config._fuse_ddp_communication_passes,
            config._fuse_ddp_bucket_size,
        )

    if config.post_grad_custom_post_pass is not None:
        with GraphTransformObserver(
            gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
        ):
            config.post_grad_custom_post_pass(gm.graph)

    stable_topological_sort(gm.graph)

    move_constructors_to_gpu(gm.graph)

    fake_tensor_updater.incremental_update()

    # Keep these last, since they introduce mutation. Look at
    # ./fx_passes/README.md for a discussion of mutation invariants.
    reinplace_inplaceable_ops(gm.graph)
    decompose_auto_functionalized(gm.graph)

    comms.reinplace_fsdp_all_gather(gm.graph)

    gm.recompile()
    optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
    gm.graph.lint()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@init_once_fakemode
def lazy_init():
    """Lazily register mkldnn-dependent pattern passes (runs once, under fake mode).

    Imports are deferred so builds without mkldnn never pay for them.
    """
    if torch._C._has_mkldnn:
        from . import decompose_mem_bound_mm  # noqa: F401
        from .mkldnn_fusion import _mkldnn_fusion_init

        _mkldnn_fusion_init()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def reorder_for_locality(graph: torch.fx.Graph):
    """Move each node's producers directly before it to improve locality.

    Walks the graph backwards; a producer is hoisted next to its consumer
    only when all of its users have already been seen (so the move is safe)
    and both nodes live in the same mutation region.
    """

    def visit(other_node):
        if (
            other_node.op == "call_function"
            and other_node.target != operator.getitem
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
        ):
            # move node's producers right before it
            node.prepend(other_node)

    seen_nodes = set()

    # only reorder nodes before the first copy_ in the graph.
    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
    # and this reordering doesn't work well with mutation
    first_copy = next(
        iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
        None,
    )
    # Idiom fix: the condition already yields the bool we want.
    past_mutating_epilogue = first_copy is None

    for node in reversed(graph.nodes):
        seen_nodes.add(node)
        if not past_mutating_epilogue:
            # Skip the trailing copy_ epilogue; start reordering once we
            # have walked past the first copy_ node.
            past_mutating_epilogue = node is first_copy
            continue

        torch.fx.map_arg((node.args, node.kwargs), visit)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
    """
    Register an aten to inductor IR replacement pattern.

    The pattern is placed in one of the three priority buckets in
    ``pass_patterns`` (lower ``pass_number`` runs first).
    """
    target_pass = pass_patterns[pass_number]
    return pattern_matcher.register_lowering_pattern(
        pattern, extra_check, pass_dict=target_pass
    )
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
################################################################################
|
| 202 |
+
# Actual patterns below this point.
|
| 203 |
+
# Priority of patterns is:
|
| 204 |
+
# - later output nodes first
|
| 205 |
+
# - order patterns are defined in
|
| 206 |
+
################################################################################
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def is_valid_mm_plus_mm(match: Match):
    """Check that mat1@mat2 + mat3@mat4 is a well-formed fused mm+mm.

    Both products must have matching inner (k) dimensions and produce the
    same (m, n) output shape; any leading batch dims are ignored.
    """

    def inner_dims(name):
        return match.kwargs[name].meta.get("tensor_meta").shape

    *_, m1, k1 = inner_dims("mat1")
    *_, k2, n1 = inner_dims("mat2")
    if k1 != k2:
        return False

    *_, m2, k3 = inner_dims("mat3")
    *_, k4, n2 = inner_dims("mat4")
    if k3 != k4:
        return False

    if m1 != m2 or n1 != n2:
        return False

    return True
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def scatter_upon_const_tensor_extra_check(m):
    """Gate the full+scatter -> pointwise rewrite.

    Only fires when the config flag is on, the selector covers the full
    tensor on every non-scatter axis, and the scatter axis of the selector
    has size 1.
    """
    if not config.optimize_scatter_upon_const_tensor:
        return False

    full_shape = m.kwargs["shape"]
    selector = m.kwargs["selector"]
    dim = m.kwargs["dim"]
    if dim < 0:
        dim += len(full_shape)

    selector_ft = selector.meta["val"]
    assert selector_ft.dim() == len(full_shape)

    for axis, (select_sz, full_sz) in enumerate(zip(selector_ft.shape, full_shape)):
        if axis == dim:
            continue

        # TODO: the pattern could be extended to selectors shorter than the
        # full tensor, but that needs a more complex condition expression
        # (especially for multi-dimensional tensors), so skip it for now.
        if isinstance(full_sz, fx.Node):
            # Symbolic sizes may arrive as fx nodes; compare the fake value.
            full_sz = full_sz.meta["val"]
        if select_sz < full_sz:
            return False

    # Supporting scatter-axis sizes larger than 1 is possible but tedious
    # (load all index values and compare against positions); not done yet.
    return selector_ft.size(dim) == 1
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@register_lowering_pattern(
    CallFunction(
        aten.scatter.value,
        CallFunction(
            aten.full,
            KeywordArg("shape"),
            KeywordArg("background_val"),
            dtype=KeywordArg("dtype"),
        ),
        KeywordArg("dim"),
        KeywordArg("selector"),
        KeywordArg("val"),  # scalar value
    ),
    extra_check=scatter_upon_const_tensor_extra_check,
)
def scatter_upon_const_tensor(
    match: Match, shape, background_val, dtype, dim, selector, val
):
    """
    Match the pattern of full+scatter into a pointwise.

    Each output element is `val` where the selector (broadcast along `dim`)
    points at its position, and `background_val` everywhere else.

    TODO: Right now the scatter value must be a scalar. But we could support it
    when it is a tensor as well.
    """
    from torch._inductor import metrics

    metrics.num_matches_for_scatter_upon_const_tensor += 1

    selector_loader = selector.make_loader()

    def inner_fn(idx):
        # The selector has size 1 along `dim` (guaranteed by the extra
        # check), so always load its 0-th entry on that axis.
        selector_idx = list(idx)
        selector_idx[dim] = 0

        # NOTE: deliberately shadows the outer `selector` inside this closure.
        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )

    return ir.Pointwise.create(
        device=selector.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=shape,
    )
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@register_lowering_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
    ),
    extra_check=is_valid_mm_plus_mm,
)
def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
    """Lower add(mm(a, b), mm(c, d)) to a single autotuned fused kernel."""
    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def cuda_and_enabled_mixed_mm(match):
    """True when the mixed-dtype mm lowering is enabled and applies to `match`.

    Requires: the config knob is on, mat1 is a CUDA tensor, the compute dtype
    is strictly wider than mat2's storage dtype, and the CUDA backend supports
    Triton templates.
    """
    if not (config.use_mixed_mm or config.mixed_mm_choice != "default"):
        return False
    if not getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False):
        return False
    # only profitable when mat2 is stored narrower than the compute dtype
    mat2_storage_itemsize = match.kwargs["mat2"].meta.get("val").dtype.itemsize
    if match.kwargs["mat2_dtype"].itemsize <= mat2_storage_itemsize:
        return False
    return has_backend_feature("cuda", BackendFeature.TRITON_TEMPLATES)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def cuda_and_enabled_mixed_mm_and_not_int8(match):
    """Like cuda_and_enabled_mixed_mm, but also rejects int8 mat2.

    bitshift numerics in triton and pytorch don't match for torch.int8.
    """
    if not cuda_and_enabled_mixed_mm(match):
        return False
    if not getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False):
        return False
    mat2_dtype = getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
    return mat2_dtype != torch.int8
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
"""
|
| 342 |
+
this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
|
| 343 |
+
(where the int4 and uint4x2 are represented with int8 and uint8 respectively)
|
| 344 |
+
where every other row of the int4 is packed with the row above it as:
|
| 345 |
+
uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4
|
| 346 |
+
|
| 347 |
+
unpack formulas:
|
| 348 |
+
int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
|
| 349 |
+
int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8
|
| 350 |
+
|
| 351 |
+
thus matching on unpack formula:
|
| 352 |
+
torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
|
| 353 |
+
|
| 354 |
+
note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
|
| 355 |
+
of the kernel matches the pytorch formula for all dtypes except torch.int8
|
| 356 |
+
where the bitwise numerics in triton do not match those in pytorch.
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@register_lowering_pattern(
    CallFunction(
        aten.mm.default,
        KeywordArg("mat1"),
        CallFunction(
            aten.sub.Tensor,
            CallFunction(
                prims.convert_element_type.default,
                CallFunction(
                    aten.reshape.default,
                    CallFunction(
                        aten.cat.default,
                        ListOf(
                            CallFunction(
                                aten.bitwise_and.Scalar,
                                KeywordArg("mat2"),
                                0xF,
                            ),
                            # CallFunction(
                            #     aten.__rshift__.Scalar,
                            #     KeywordArg("mat2"),
                            #     4,
                            # ),
                            True,
                        ),
                        1,
                    ),
                    KeywordArg("mat2_mm_shape"),
                ),
                KeywordArg("mat2_dtype"),
            ),
            8,
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
)
def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
    """Lower the uint4x2 unpack-then-mm pattern to a fused autotuned kernel."""
    return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
        mat1, mat2, mat2_mm_shape, mat2_dtype
    )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
"""
|
| 403 |
+
torch.mm(mat1, mat2.to(mat2_dtype))
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
@register_lowering_pattern(
    CallFunction(
        aten.mm,
        KeywordArg("mat1"),
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg("mat2"),
            KeywordArg("mat2_dtype"),
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm,
)
def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
    """Lower torch.mm(mat1, mat2.to(mat2_dtype)) to a fused mixed-dtype kernel."""
    return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
@register_graph_pattern(
    CallFunction(
        aten.cumsum.default,
        CallFunction(
            torch.ops.aten.full.default,
            KeywordArg("shape"),
            KeywordArg("fill_value"),
            dtype=KeywordArg("dtype"),
            layout=Ignored(),
            device=KeywordArg("device"),
            pin_memory=False,
            _users=MULTIPLE,
        ),
        KeywordArg("dim"),
        _users=MULTIPLE,
    ),
    pass_dict=pass_patterns[1],
)
def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
    """Replace cumsum(full(shape, c)) with arange(1..n) * c. Based on a pattern in OPTForCausalLM."""

    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
        # cumsum promotes all integral types to int64
        dtype = torch.int64

    def repl(*shape):
        dim_size = shape[dim]
        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)

        # broadcast the 1-D ramp along every other dim
        inter_shape = [1] * len(shape)
        inter_shape[dim] = dim_size
        return (idx * fill_value).view(inter_shape).expand(shape)

    # only replace the output node, not all nodes
    match.nodes = [match.output_node()]
    match.replace_by_example(repl, list(shape))
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def shape_of_mm(a, b):
    """Return the output size [m, n] of mm(a, b) from the operands' sizes."""
    rows = a.get_size()[0]
    cols = b.get_size()[1]
    return [rows, cols]
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
@register_lowering_pattern(
    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
    """Lower cat-of-mms by writing each mm directly into the concat buffer."""
    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@register_lowering_pattern(
    CallFunction(
        aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
    ),
)
def cat_addmm(match, inputs, dim):
    """Lower cat-of-addmms by writing each addmm directly into the concat buffer."""

    def shape_of(bias, a, b):
        # addmm output shape is determined by the two matrices, not the bias
        m, _ = a.get_size()
        _, n = b.get_size()
        return [m, n]

    return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def cat_tuned_op(match, inputs, dim, *, op, shape_of):
    """
    Memory planning to remove cat. We can't use the stock memory
    planner since autotuning matmuls needs to know the output layout.
    """
    if len(inputs) == 1:
        return op(*inputs[0])

    # TODO(jansel): rewrite this as a bmm?
    if dim < 0:
        dim += len(shape_of(*inputs[0]))
    assert dim in (0, 1)
    notdim = 1 - dim

    new_size: Optional[Union[List[Expr], List[int]]] = None
    offsets_start = []
    offsets_end = []

    # accumulate the concatenated output size and each input's slice
    # offsets along `dim`
    for operands in inputs:
        shape = shape_of(*operands)
        if new_size is None:
            new_size = shape
        else:
            # the non-concat dim must agree across all inputs
            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
                shape[notdim], new_size[notdim]
            )
            new_size[dim] += shape[dim]
        offsets_start.append(new_size[dim] - shape[dim])
        offsets_end.append(new_size[dim])

    assert new_size is not None
    dtype = functools.reduce(
        torch.promote_types,
        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
    )
    device = inputs[0][0].get_device()
    kernel = ir.ConcatKernel(
        name=None,
        layout=ir.FixedLayout(device, dtype, new_size),
        inputs=[],
    )
    kernel_tensor = ir.TensorBox.create(kernel)

    # lower each op so it writes directly into its slice of the concat buffer
    for operands, start, end in zip(inputs, offsets_start, offsets_end):
        dst = ir.SliceView.create(kernel_tensor, dim, start, end)
        src = op(*operands, layout=dst.get_layout()).data.data
        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
        src.layout = ir.NonOwningLayout(dst)
        kernel.inputs.append(src)

    kernel.name = V.graph.register_buffer(kernel)
    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
    V.graph.register_operation(kernel)
    return kernel_tensor
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)


@register_lowering_pattern(
    CallFunction(
        aten.cat,
        [
            _cat_1,
            CallFunction(
                aten.slice,
                _cat_1,
                1,
                0,
                KeywordArg("size"),
            ),
        ],
        1,
    )
)
def cat_slice_cat(match, cat_input, size, dim=1):
    """
    This is an example of a more complex pattern where cat_1 is used
    multiple times inside the pattern. We fold 2 calls to cat into one.

    Matches:
        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)


    Rewrite to:
        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
    """
    first, *rest = cat_input
    # Optimization is optional, because we can just not fold the cat
    # size should be within first.get_size()[dim] such that the optimization is valid.
    # For negative `end`, we currently fallback to not optimizing.
    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
        # fold 2 cats into 1 cat
        return L[aten.cat](
            [first, *rest, L[aten.slice](first, dim, 0, size)],
            dim,
        )

    # don't expect to hit this case, just fall back
    tmp = L[aten.cat](cat_input, dim)
    return L[aten.cat](
        [tmp, L[aten.slice](tmp, dim, 0, size)],
        dim,
    )
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def is_valid_splitwithsizes_cat(match):
    """True when cat(getitem(split_with_sizes(x))...) is an exact passthrough of x."""
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]

    # The dim of split and cat should match for passthrough
    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
        return False

    get_item_args = {
        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
    }
    assert None not in get_item_args
    split_sizes = get_arg_value(split_node, 1, "split_sizes")

    # All parts of split should be included in the cat
    if get_item_args != set(range(len(split_sizes))):
        return False

    # The order of get_item_args should be the same as the cat_node uses them.
    # For example, if the split_node is split_with_sizes(input, [2, 2, 3], 1),
    # the cat node should be cat([get_item(0), get_item(1), get_item(2)], 1).
    cat_items_args_order = [
        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
    ]
    return cat_items_args_order == list(range(len(split_sizes)))
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
    """True if two nodes have the same metadata (size, layout, dtype, device, strides)."""
    val1 = node1.meta.get("val")
    val2 = node2.meta.get("val")
    if val1 is None or val2 is None:
        return False
    if not statically_known_true(sym_eq(val1.size(), val2.size())):
        return False
    if (
        val1.layout != val2.layout
        or val1.dtype != val2.dtype
        or val1.device != val2.device
    ):
        return False
    # strided tensors must additionally agree on strides
    return val1.layout != torch.strided or statically_known_true(
        sym_eq(val1.stride(), val2.stride())
    )
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
# maps an op target -> (condition_fn, nop_arg) used by remove_noop_ops
noop_registry: Dict[Any, Any] = {}


def register_noop_decomp(targets, nop_arg=0):
    """Decorator: register `cond` as the "is this op a no-op?" predicate for
    `targets`, with `nop_arg` selecting which argument the op passes through."""

    def register_fun(cond):
        register_decomposition(targets, registry=noop_registry, unsafe=True)(
            (cond, nop_arg)  # type: ignore[arg-type]
        )
        return cond

    return register_fun
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
@register_noop_decomp(aten.slice)
def slice_noop(self, dim=0, start=None, end=None, step=1):
    """A slice is a no-op when it covers the whole dim with step 1."""
    if start is None or end is None:
        return False
    return bool(
        statically_known_true(sym_eq(start, 0))
        and statically_known_true(end >= 2**63 - 1)
        and statically_known_true(sym_eq(step, 1))
    )
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
@register_noop_decomp(aten.slice_scatter, 1)
def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
    """slice_scatter is a no-op (returns `src`) when the slice covers the whole dim."""
    effective_start = 0 if start is None else start
    effective_end = 2**63 - 1 if end is None else end
    return effective_start == 0 and effective_end >= 2**63 - 1 and step == 1
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
@register_noop_decomp(aten.repeat)
def repeat_noop(self, repeats):
    """repeat is an identity when every dim is repeated exactly once."""
    return all(count == 1 for count in repeats)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
@register_noop_decomp(aten.constant_pad_nd)
def constant_pad_nd(x, padding, fill_value=0):
    """Padding by zero amounts on every side is an identity."""
    return all(amount == 0 for amount in padding)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
@register_noop_decomp(torch.ops.prims.convert_element_type)
def convert_element_type_noop(x, dtype: torch.dtype):
    """Converting to the dtype the tensor already has is an identity."""
    return x.dtype == dtype
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
@register_noop_decomp(torch.ops.prims.device_put)
def device_put_noop(x, device):
    """Moving a tensor to the device it already lives on is an identity."""
    return x.device == decode_device(device)
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
def int_noop(x):
    """Rounding ops are identities on integer tensors."""
    return is_integer_dtype(x.dtype)
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
@register_noop_decomp([aten.pow])
def pow_noop(a, b):
    """Raising to the integer power 1 is an identity."""
    return isinstance(b, int) and b == 1
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
@register_noop_decomp([aten.cat], lambda args: args[0][0])
def cat_noop(inputs, dim=0):
    """cat of a single tensor is an identity (passthrough of that tensor)."""
    return len(inputs) == 1
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
@register_noop_decomp(aten.view)
def view_noop(arg, size):
    """view to the shape the tensor already has is an identity."""
    return arg.shape == size
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
# Note, we also always have a check for identical metadata, which is why these
# are safe
@register_noop_decomp([aten.copy], nop_arg=1)
@register_noop_decomp([aten.alias, aten.clone])
def true_noop(*args, **kwargs):
    """Unconditionally a no-op; metadata equality is checked by the caller."""
    return True
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def remove_noop_ops(graph: torch.fx.Graph):
    """
    Removes both operations that are essentially aten.clone and operations
    that are essentially aten.alias from the graph.
    """
    inputs = set()
    input_storages = set()
    output_storages = set()

    for node in graph.find_nodes(op="placeholder"):
        inputs.add(node)
        input_storages.add(get_node_storage(node))

    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output"
    outputs = output_node.args[0]
    if not isinstance(outputs, (list, tuple)):
        # nested subgraphs can have singleton outputs
        outputs = (outputs,)
    for out in outputs:
        if isinstance(out, torch.fx.Node):
            output_storages.add(get_node_storage(out))

    for node in graph.nodes:
        if node.target not in noop_registry:
            continue
        cond, src_index = noop_registry[node.target]
        if isinstance(src_index, int):
            src = node.args[src_index]
        else:
            src = src_index(node.args)
        if not isinstance(src, torch.fx.Node):
            continue

        # Don't introduce new aliasing between inputs and outputs.
        # See fx_passes/README.md for a discussion of why this is
        # necessary.
        node_storage = get_node_storage(node)
        src_storage = get_node_storage(src)
        node_is_view = node_storage == src_storage
        if (
            not node_is_view
            and node_storage in output_storages
            and (src_storage in input_storages or src_storage in output_storages)
        ):
            continue

        # Even if input and outputs are expected to alias,
        # don't make "node is src" True
        if (
            node_is_view
            and node in output_node.args
            and (src in inputs or src in output_node.args)
        ):
            continue

        is_valid, args, kwargs = get_fake_args_kwargs(node)
        if not is_valid:
            continue
        if same_meta(node, src) and cond(*args, **kwargs):
            node.replace_all_uses_with(src)
            graph.erase_node(node)
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
def decompose_auto_functionalized(graph):
    """Decomposes auto_functionalized and triton_kernel_wrapper_functional
    nodes into clones and the underlying mutation node.

    We assume that the reinplacing pass runs before this; the reinplacing pass
    tells us (via rewriting the arguments or .meta to those nodes) which
    Tensors we should clone and which Tensors are safe to reinplace.
    """
    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense

        only_clone_these_tensors = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_functional_dense,
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return (triton_kernel_wrapper_functional_dense(*args, **kwargs),)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import (
            auto_functionalized_v2_dense,
        )

        only_clone_these_bases = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    graph_pass.apply(graph)

    # After the pass, none of these HOPs may remain in the graph.
    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized
    ):
        raise AssertionError("auto_functionalized was not removed")

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
    ):
        raise AssertionError("auto_functionalized_v2 was not removed")

    for node in graph.find_nodes(
        op="call_function",
        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
    ):
        raise AssertionError("triton_kernel_wrapper_functional was not removed")
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
@register_lowering_pattern(
    CallFunction(
        aten.cat,
        ListOf(
            CallFunction(
                operator.getitem,
                CallFunction(
                    aten.split_with_sizes,
                    KeywordArg("input_"),
                    Ignored(),
                    Ignored(),
                    _users=MULTIPLE,
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_splitwithsizes_cat,
)
def splitwithsizes_cat_replace(match, input_):
    """cat of all the pieces of a split (in order, same dim) is the input itself."""
    return input_
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
def is_valid_cat_splitwithsizes(match):
    """True when split_with_sizes(cat(xs)) exactly undoes the cat."""
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]

    # the cat node has other users: can't eliminate
    if len(cat_node.users) > 1:
        return False

    # the dim of the cat and split should match
    dim = get_arg_value(split_node, 2, "dim")
    if dim != get_arg_value(cat_node, 1, "dim"):
        return False

    cat_inputs = list(get_arg_value(cat_node, 0))
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # the number of input tensors in cat and the
    # length of the split sizes should match
    if len(cat_inputs) != len(split_sizes):
        return False

    # each cat input tensor's size along dim
    # should match the corresponding split size
    for cat_input, split_size in zip(cat_inputs, split_sizes):
        if "val" not in cat_input.meta:
            return False
        if cat_input.meta["val"].size(dim) != split_size:
            return False

    return True
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
@register_lowering_pattern(
    CallFunction(
        aten.split_with_sizes,
        CallFunction(
            aten.cat,
            KeywordArg("input_"),
            Ignored(),
            _users=MULTIPLE,
        ),
        Ignored(),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_cat_splitwithsizes,
)
def cat_splitwithsizes_replace(match, input_):
    """split_with_sizes that exactly undoes a cat passes through the cat inputs."""
    return input_
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
def view_to_reshape(gm):
    """Replace aten.view ops in the GraphModule with aten.reshape ops."""
    view_nodes = gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.view.default
    )
    for node in view_nodes:
        node.target = torch.ops.aten.reshape.default
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
def should_prefer_unfused_addmm(match):
    """Prefer mm + pointwise-add over addmm when the bias add can fuse into
    the (all-pointwise) consumers of the output on CUDA."""
    inp = match.kwargs["inp"]
    if not inp.meta["val"].is_cuda:
        return False

    out = match.output_node()
    return all(is_pointwise_use(use) for use in out.users)
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    """Split addmm(inp, m1, m2) into m1 @ m2 + inp so the add can fuse downstream."""

    def repl(inp, x1, x2):
        return x1 @ x2 + inp

    match.replace_by_example(repl, [inp, mat1, mat2])
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def is_valid_addmm_fusion(match):
    """Check that add(mm(m1, m2), inp) can legally become addmm(inp, m1, m2)."""
    mat1, mat2 = match.args
    inp = match.kwargs["inp"]

    if not (
        isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
    ):
        return False  # Input is a number

    in_shape = inp.meta["val"].shape
    mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
    if not is_expandable_to(in_shape, mm_shape):
        return False  # Shape mismatch

    return not should_prefer_unfused_addmm(match)
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
@register_graph_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, Arg(), Arg()),
        KeywordArg("inp"),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
@register_graph_pattern(
    CallFunction(
        aten.add,
        KeywordArg("inp"),
        CallFunction(aten.mm, Arg(), Arg()),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
def addmm(match, mat1, mat2, *, inp):
    """Fuse add(mm(m1, m2), inp) — in either operand order — into addmm."""

    def repl(inp, mat1, mat2):
        return aten.addmm(inp, mat1, mat2)

    match.replace_by_example(repl, [inp, mat1, mat2])
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
    """Gate for the fused _int_mm + mul lowering: config knob on, and the mul
    operand (args[2]) is a 2-D CUDA tensor."""
    if not config.force_fuse_int_mm_with_mul:
        return False
    mul_operand_val = match.args[2].meta.get("val")
    if len(getattr(mul_operand_val, "shape", [])) != 2:
        return False
    return getattr(mul_operand_val, "is_cuda", False)
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
@register_lowering_pattern(
    CallFunction(
        prims.convert_element_type.default,
        CallFunction(
            aten.mul,
            CallFunction(
                aten._int_mm,
                Arg(),
                Arg(),
            ),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
@register_lowering_pattern(
    CallFunction(
        aten.mul,
        CallFunction(
            aten._int_mm,
            Arg(),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
    """Lower _int_mm followed by mul (optionally with a dtype cast) to one kernel."""
    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
    """True when `node` is an index_put(_) whose value should stay a cpu scalar.

    Inductor falls back to aten.index_put_. index_put_ will call nonzero()
    and perform a H2D sync if any of its indices are bool/byte tensors.
    However, it will short-circuit this H2D sync and run mask_fill_ if the
    value we are putting is a cpu scalar. Therefore, when inductor sees an
    index_put_ with byte tensor indices, it should *not* convert the cpu
    scalar value into a gpu tensor.
    """
    from torch.fx.operator_schemas import normalize_function

    if node.target not in [
        torch.ops.aten.index_put.default,
        torch.ops.aten.index_put_.default,
    ]:
        return False
    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
    # NOTE(review): "byte" tensors are torch.uint8, yet this checks torch.int8 —
    # confirm the intended dtype set against eager index_put_ behavior.
    any_byte_bool_indices = False
    indices = args_[1]
    for index in indices:
        if index is not None and index.meta["val"].dtype in [torch.bool, torch.int8]:
            any_byte_bool_indices = True

    val = args_[2].meta["val"]
    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
    # If both these conditions hold, then converting the val to a gpu tensor
    # will incur a H2D sync when inductor calls aten.index_put_
    return any_byte_bool_indices and val_is_cpu_scalar
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
class ConstructorMoverPass:
    """Graph pass that rewrites cpu tensor constructors to the target device."""

    def __init__(self, target: str, allow_outputs: bool = False) -> None:
        """
        Move constructors from cpu to the target_device.

        Sweeps through the module, looking for constructor nodes that can be moved
        to the target_device.

        A constructor node can be moved to the target_device iff all of its users
        can also be moved (tested by cannot_be_moved). Otherwise, all dependent
        constructor nodes won't be moved.

        - target: target device type
        - allow_outputs: allow outputs to be moved
        """

        self.target = target
        self.allow_outputs = allow_outputs

        assert isinstance(target, str), (
            "target should be a string representing the device type. "
            f"Got: {type(target).__name__}"
        )

    def allow_cpu_device(self, node: fx.Node) -> bool:
        """
        Returns whether a node that returns a tensor on the target device may have
        cpu tensors as input.
        """
        # These ops accept mixed-device inputs (e.g. cpu indices with a gpu
        # self tensor), so a cpu input here does not block the move.
        return node.target in (
            torch.ops.aten.index.Tensor,
            torch.ops.aten.index_put.default,
            torch.ops.aten.index_put_.default,
            torch.ops.aten.copy.default,
            torch.ops.aten.copy_.default,
            torch.ops.aten.slice_scatter.default,
        )

    def cannot_be_moved(self, node: fx.Node) -> bool:
        """
        Returns whether a node can be moved to the target device.

        If this function returns False, it means that this node and all of its users
        won't be moved into the target device.
        """
        if node.target == "output":
            return not self.allow_outputs

        # Only aten/prims OpOverloads have device semantics this pass
        # understands; anything else is conservatively pinned.
        if not (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target.namespace in ("prims", "aten")
        ):
            return True
        # index_put with a cpu-scalar value must keep that value on cpu;
        # moving it would introduce an H2D sync (see helper above).
        if is_index_put_and_requires_h2d_sync_for_gpu_value(node):
            return True

        return False

    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
        """
        Get the device of a node.
        """
        ten = node.meta.get("val")
        return None if not isinstance(ten, torch.Tensor) else ten.device

    def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
        """
        Get the number of cpu inputs to a node
        """
        cpu_indeg: Dict[fx.Node, int] = Counter()

        for node in graph.nodes:
            cpu_count = 0

            def add_cpu_inp(node):
                # Counts each fx.Node argument whose value lives on cpu.
                nonlocal cpu_count
                device = self.get_node_device(node)
                cpu_count += device is not None and device.type == "cpu"

            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))

            if cpu_count:
                cpu_indeg[node] = cpu_count

        return cpu_indeg

    def __call__(self, graph: fx.Graph) -> None:
        """Find movable cpu constructors and rewrite their device kwarg."""
        target_devices = set()
        constructors = []

        for node in graph.nodes:
            device = self.get_node_device(node)
            if device and device.type == self.target:
                target_devices.add(device)

            if not (
                isinstance(node.target, torch._ops.OpOverload)
                and node.target.namespace in ("prims", "aten")
            ):
                continue

            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
                continue

            # Only constructors explicitly placed on cpu are candidates.
            if not node.kwargs.get("device") == torch.device("cpu"):
                continue

            constructors.append(node)

        # not handling multiple target devices initially
        if not constructors or len(target_devices) != 1:
            return

        movable_constructors = self.find_movable_constructors(graph, constructors)

        for node in movable_constructors:
            kwargs = node.kwargs.copy()
            kwargs["device"] = next(iter(target_devices))
            node.kwargs = kwargs

    def find_movable_constructors(
        self, graph: fx.Graph, constructors: List[fx.Node]
    ) -> Set[fx.Node]:
        """
        Starting from the cpu constructors, iterate through the graph and test that all of their
        downstream uses can safely be moved to cpu.
        """
        cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)

        # which constructors cannot be moved to gpu
        cannot_move_to_gpu: Set[fx.Node] = set()

        # For any node in the graph, which constructors does it have a dependency on
        constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)

        # if a cpu node has a dependency on two different cpu constructors,
        # then if either constructor cannot be moved to gpu, the other cannot as well.
        # In this case any node with a dependency on one will have a dependency on the other
        equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
            c: {c} for c in constructors
        }

        def make_dependencies_equivalent(
            set1: Set[fx.Node], set2: Set[fx.Node]
        ) -> Set[fx.Node]:
            # could use union find but not worth complexity here
            set1.update(set2)
            for obj in set1:
                equal_constructor_sets[obj] = set1
            return set1

        queue: List[fx.Node] = list(constructors)

        # Each constructor trivially depends on itself.
        for c in queue:
            constructor_dependencies[c].add(c)

        while queue:
            node = queue.pop()
            dependencies = constructor_dependencies[node]

            for user in node.users:
                if self.cannot_be_moved(user):
                    cannot_move_to_gpu.update(dependencies)
                    break

                # this node was used on a op which takes in multiple devices and output a gpu
                # tensor. we can convert its cpu input to gpu without making further changes
                node_device = self.get_node_device(user)
                if (
                    self.allow_cpu_device(user)
                    and node_device
                    and node_device.type == self.target
                ):
                    del cpu_indeg[user]
                else:
                    # otherwise, we should continue look at its downstream uses
                    cpu_indeg[user] -= 1
                    if cpu_indeg[user] == 0:
                        del cpu_indeg[user]
                        queue.append(user)

                unioned_set = make_dependencies_equivalent(
                    dependencies, constructor_dependencies[user]
                )
                constructor_dependencies[user] = unioned_set

        # Any node still waiting on cpu inputs pins its constructors on cpu.
        for node in cpu_indeg:
            if constructor_dependencies[node]:
                cannot_move_to_gpu.update(constructor_dependencies[node])

        # Expand through the equivalence sets: if one member of a set is
        # pinned, every member of that set is pinned.
        all_cannot_move_to_gpu = cannot_move_to_gpu.copy()
        for constructor in cannot_move_to_gpu:
            all_cannot_move_to_gpu.update(equal_constructor_sets[constructor])

        return set(constructors) - all_cannot_move_to_gpu
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
def move_constructors_to_gpu(graph: fx.Graph) -> None:
    """Moves intermediary tensors which are constructed on the cpu to gpu when safe"""
    mover = ConstructorMoverPass(get_gpu_type())
    mover(graph)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import copy
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
|
| 10 |
+
from torch._utils_internal import upload_graph
|
| 11 |
+
from torch.fx.experimental.optimization import (
|
| 12 |
+
matches_module_pattern,
|
| 13 |
+
replace_node_module,
|
| 14 |
+
)
|
| 15 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 16 |
+
from torch.fx.passes.shape_prop import ShapeProp
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
|
| 19 |
+
|
| 20 |
+
from .. import config
|
| 21 |
+
from ..fx_utils import matches_module_function_pattern
|
| 22 |
+
from ..pattern_matcher import (
|
| 23 |
+
init_once_fakemode,
|
| 24 |
+
PatternMatcherPass,
|
| 25 |
+
stable_topological_sort,
|
| 26 |
+
)
|
| 27 |
+
from ..utils import is_cpu_device, pass_execution_and_save
|
| 28 |
+
from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
|
| 29 |
+
from .misc_patterns import numpy_compat_normalization
|
| 30 |
+
from .split_cat import PRE_GRAD_PATTERNS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
log = logging.getLogger(__name__)

# Pattern-matcher passes applied on Torch IR (pre-grad, non-predispatch path).
efficient_conv_bn_eval_pass = PatternMatcherPass(
    pass_name="efficient_conv_bn_eval_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass()
merge_splits_pass_aten = PatternMatcherPass()
split_cat_pass_aten = PatternMatcherPass()
unbind_stack_pass_aten = PatternMatcherPass()
merge_getitem_cat_pass_aten = PatternMatcherPass()
merge_stack_tahn_unbind_pass_aten = PatternMatcherPass()
mutate_cat_pass_aten = PatternMatcherPass()
remove_split_with_size_one_pass_aten = PatternMatcherPass()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def save_inductor_dict(pass_to_compare=None):
    """
    Snapshot the "inductor" counters for the given pass names.

    Args:
        pass_to_compare: iterable of pass names to record; when falsy,
            defaults to all configured pre/post-grad fusion passes.

    Returns:
        Dict mapping each pass name to its current counter value (0 if unset).
    """
    if not pass_to_compare:
        pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list(
            config.post_grad_fusion_options.keys()
        )
    # Hoisted out of the comprehension: the original rebuilt this dict once
    # per pass name, which is loop-invariant work.
    inductor_counters = dict(counters["inductor"])
    return {p: inductor_counters.get(p, 0) for p in pass_to_compare}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def is_same_dict(inductor_dict, optimus_dict):
    """
    Return True iff every counter recorded in ``optimus_dict`` still has the
    same value in ``inductor_dict`` (missing keys count as 0).
    """
    # Hoisted out of the loop: the original converted inductor_dict to a
    # plain dict on every iteration.
    current = dict(inductor_dict)
    return all(
        count == current.get(pass_name, 0)
        for pass_name, count in optimus_dict.items()
    )
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def normalize_node_kwargs_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def fuse_parallel_linear_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def remove_split_ops(graph, shape_prop):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def fuse_chunk_reshape_unsqueeze_concat_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def fuse_chunk_reshape_concat_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def remove_noop_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def stack_to_unsqueeze_pass(graph):
    # No-op placeholder.  NOTE(review): looks like an OSS stub of a pass
    # implemented in an internal (fb) module — confirm before relying on it.
    return None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@init_once_fakemode
def lazy_init():
    """Register pattern-matcher patterns once, under a fake-tensor mode."""
    # Importing these modules registers their patterns as a side effect.
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        from . import fb  # type: ignore[attr-defined]  # noqa: F401
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is not functional or normalized, so it is harder
    to write passes on this IR. Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py which
    are after functionalization and normalization.

    Args:
        gm: the captured GraphModule in Torch IR.
        example_inputs: sample inputs used for shape propagation and the
            optional numeric check; some passes are skipped when None.

    Returns:
        The transformed GraphModule.
    """
    if config.pattern_matcher:
        lazy_init()
        # Keep a pristine copy so the numeric check below can compare
        # before/after behavior.
        if hasattr(
            config, "fx_passes_numeric_check"
        ) and config.fx_passes_numeric_check.get("pre_grad", False):
            gm_before_fx_passes = gm.__copy__()
        # explicitly run with predispatch atenIR based passes
        if config.is_predispatch:

            def shape_prop(mod) -> None:
                ShapeProp(
                    gm=mod,
                    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
                    fake_mode=detect_fake_mode(example_inputs),
                ).propagate(*example_inputs)

            # normalization pass
            pass_execution_and_save(
                normalization_pass_aten.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalization pass",
            )
            # normalize kwargs, must be called as the first pass
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply remove_noop pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_concat_pass",
            )
            pass_execution_and_save(
                group_batch_fusion_passes,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply group_batch_fusion",
            )
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                fuse_chunk_squeeze_cat_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
            )
            pass_execution_and_save(
                fuse_split_linear_add_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
            )
            pass_execution_and_save(
                remove_reshape_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
            )
            pass_execution_and_save(
                fuse_parallel_linear_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
            )
            pass_execution_and_save(
                lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_split_ops",
            )
            # run before fuse_chunk_reshape_unsqueeze_concat_pass
            pass_execution_and_save(
                stack_to_unsqueeze_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply stack_to_unsqueeze_pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_unsqueeze_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_unsqueeze_concat_pass",
            )
            # Remove noops at the end, which may be generated other passes.
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)]Apply remove_noop pass",
            )
            shape_prop(gm)

        else:
            # We only log the graph with changes to avoid the excessive compilation time
            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
            if example_inputs is not None:
                gm = fuse_fx(gm, example_inputs)
            numpy_compat_normalization(gm.graph)
            optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
            group_batch_fusion_passes(gm.graph, pre_grad=True)
            for pass_name in config.pre_grad_fusion_options:
                # skip all patterns for group batch fusions
                if pass_name in PRE_GRAD_FUSIONS:
                    continue
                pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
                # Snapshot the counters so we can detect whether this pass
                # actually changed anything before uploading the graph.
                inductor_before_change = save_inductor_dict(
                    [pattern_matcher_pass.pass_name]
                )
                # we support run same pattern multiple times, the default is to run only once
                counter = config.pre_grad_fusion_options[pass_name].get("counter", 1)
                for _ in range(counter):
                    pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
                if not is_same_dict(counters["inductor"], inductor_before_change):
                    optimus_scuba_log[
                        f"{pattern_matcher_pass.pass_name}_pre_grad"
                    ] = upload_graph(gm.graph)
            # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
            efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]

    if config.pre_grad_custom_pass is not None:
        with GraphTransformObserver(
            gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
        ):
            config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)

    from .quantization import quant_lift_up

    quant_lift_up(gm)

    gm.graph.lint()
    gm.recompile()
    optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph)

    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """
    Apply Torch-IR fusions (permute/linear/matmul, and conv+bn under
    freezing) to ``gm``.  Returns the transformed GraphModule.
    """
    is_cpu = is_cpu_device(example_inputs)
    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        with GraphTransformObserver(
            gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = linear_permute_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_linear_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_matmul_fusion(gm)

    # make sure the autograd is disabled.
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        with GraphTransformObserver(
            gm, "remove_identity", config.trace.log_url_for_graph_xform
        ):
            gm = remove_identity(gm)
        with GraphTransformObserver(
            gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
        ):
            gm = fuse_conv_bn(gm)
    return gm
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def fetch_attr(target: str, mod):
    """
    Resolve a dotted attribute path (e.g. ``"sub.weight"``) on ``mod``.

    Args:
        target: dot-separated attribute path.
        mod: root object to resolve against.

    Returns:
        The attribute at the end of the path.

    Raises:
        RuntimeError: if a component of the path does not exist; the message
            names the path prefix resolved before the failure.
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Fix: correct the spelling of "nonexistent" in the error message.
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Strip every ``nn.Identity`` submodule call out of the traced module."""

    class _DropIdentity(torch.fx.Transformer):
        # When the called submodule is an Identity, forward its single input
        # straight through; otherwise transform the call as usual.
        def call_module(self, target, args, kwargs):
            submod = self.submodules[target]
            if not isinstance(submod, nn.Identity):
                return super().call_module(target, args, kwargs)
            assert len(args) == 1
            return args[0]

    return _DropIdentity(gm).transform()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
|
| 358 |
+
"""
|
| 359 |
+
Fuses Convolution/BN layers for inference purposes.
|
| 360 |
+
"""
|
| 361 |
+
modules_patterns = [
|
| 362 |
+
(torch.nn.Conv1d, torch.nn.BatchNorm1d),
|
| 363 |
+
(torch.nn.Conv2d, torch.nn.BatchNorm2d),
|
| 364 |
+
(torch.nn.Conv3d, torch.nn.BatchNorm3d),
|
| 365 |
+
]
|
| 366 |
+
module_function_patterns = [
|
| 367 |
+
(torch.nn.Conv1d, F.batch_norm),
|
| 368 |
+
(torch.nn.Conv2d, F.batch_norm),
|
| 369 |
+
(torch.nn.Conv3d, F.batch_norm),
|
| 370 |
+
]
|
| 371 |
+
modules = dict(gm.named_modules())
|
| 372 |
+
|
| 373 |
+
class ConvBNFusion:
|
| 374 |
+
def __init__(
|
| 375 |
+
self,
|
| 376 |
+
bn_node,
|
| 377 |
+
conv_module,
|
| 378 |
+
bn_module=None, # For BN Module
|
| 379 |
+
bn_running_mean=None, # For Functional BN
|
| 380 |
+
bn_running_var=None,
|
| 381 |
+
bn_eps=None,
|
| 382 |
+
bn_weight=None,
|
| 383 |
+
bn_bias=None,
|
| 384 |
+
) -> None:
|
| 385 |
+
self.bn_nodes = [
|
| 386 |
+
bn_node,
|
| 387 |
+
]
|
| 388 |
+
self.conv_module = conv_module
|
| 389 |
+
self.bn_module = bn_module
|
| 390 |
+
self.bn_running_mean = bn_running_mean
|
| 391 |
+
self.bn_running_var = bn_running_var
|
| 392 |
+
self.bn_eps = bn_eps
|
| 393 |
+
self.bn_weight = bn_weight
|
| 394 |
+
self.bn_bias = bn_bias
|
| 395 |
+
self.fusion_enabled = True
|
| 396 |
+
|
| 397 |
+
def add_bn_node(self, bn_node):
|
| 398 |
+
self.bn_nodes.append(bn_node)
|
| 399 |
+
|
| 400 |
+
def disable_fusion(self):
|
| 401 |
+
self.fusion_enabled = False
|
| 402 |
+
|
| 403 |
+
def is_fusion_enabled(self):
|
| 404 |
+
return self.fusion_enabled
|
| 405 |
+
|
| 406 |
+
conv_bn_to_fuse: Dict[int, ConvBNFusion] = {}
|
| 407 |
+
for pattern in modules_patterns:
|
| 408 |
+
conv_bn_to_fuse.clear()
|
| 409 |
+
for node in gm.graph.nodes:
|
| 410 |
+
if matches_module_pattern(pattern, node, modules):
|
| 411 |
+
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
|
| 412 |
+
continue
|
| 413 |
+
conv = modules[node.args[0].target]
|
| 414 |
+
bn = modules[node.target]
|
| 415 |
+
eval_mode = all(not n.training for n in [conv, bn])
|
| 416 |
+
if not eval_mode:
|
| 417 |
+
continue
|
| 418 |
+
if not bn.track_running_stats:
|
| 419 |
+
continue
|
| 420 |
+
|
| 421 |
+
# Do hash based on the module name of conv
|
| 422 |
+
hash_id = hash(node.args[0].target)
|
| 423 |
+
if hash_id not in conv_bn_to_fuse:
|
| 424 |
+
conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn)
|
| 425 |
+
else:
|
| 426 |
+
if bn == conv_bn_to_fuse[hash_id].bn_module:
|
| 427 |
+
# Do fusion if same bn module
|
| 428 |
+
conv_bn_to_fuse[hash_id].add_bn_node(node)
|
| 429 |
+
else:
|
| 430 |
+
# Disable the conv bn folding if conv shared by different bn
|
| 431 |
+
conv_bn_to_fuse[hash_id].disable_fusion()
|
| 432 |
+
|
| 433 |
+
for conv_bn_fusion in conv_bn_to_fuse.values():
|
| 434 |
+
if conv_bn_fusion.is_fusion_enabled():
|
| 435 |
+
bn_nodes = conv_bn_fusion.bn_nodes
|
| 436 |
+
conv = conv_bn_fusion.conv_module
|
| 437 |
+
bn = conv_bn_fusion.bn_module
|
| 438 |
+
|
| 439 |
+
fused_conv = fuse_conv_bn_eval(conv, bn)
|
| 440 |
+
for bn_node in bn_nodes:
|
| 441 |
+
replace_node_module(bn_node.args[0], modules, fused_conv)
|
| 442 |
+
bn_node.replace_all_uses_with(bn_node.args[0])
|
| 443 |
+
gm.graph.erase_node(bn_node)
|
| 444 |
+
|
| 445 |
+
gm.graph.lint()
|
| 446 |
+
for pattern in module_function_patterns:
|
| 447 |
+
conv_bn_to_fuse.clear()
|
| 448 |
+
for node in gm.graph.nodes:
|
| 449 |
+
if matches_module_function_pattern(pattern, node, modules):
|
| 450 |
+
# TODO: support kwargs.
|
| 451 |
+
if len(node.args) != 8:
|
| 452 |
+
continue
|
| 453 |
+
conv = modules[node.args[0].target]
|
| 454 |
+
bn_training = node.args[5]
|
| 455 |
+
bn_eps = node.args[7]
|
| 456 |
+
if conv.training or bn_training:
|
| 457 |
+
continue
|
| 458 |
+
if type(bn_eps) is not float:
|
| 459 |
+
continue
|
| 460 |
+
|
| 461 |
+
def _used_by_same_conv_module(users):
|
| 462 |
+
conv_module_name = users[0].args[0].target
|
| 463 |
+
return all(
|
| 464 |
+
conv_module_name == user.args[0].target for user in users
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
bn_args_is_constant = all(
|
| 468 |
+
n.op == "get_attr"
|
| 469 |
+
and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users)))
|
| 470 |
+
for n in node.args[1:5]
|
| 471 |
+
)
|
| 472 |
+
if not bn_args_is_constant:
|
| 473 |
+
continue
|
| 474 |
+
bn_running_mean = fetch_attr(node.args[1].target, gm)
|
| 475 |
+
bn_running_var = fetch_attr(node.args[2].target, gm)
|
| 476 |
+
bn_weight = fetch_attr(node.args[3].target, gm)
|
| 477 |
+
bn_bias = fetch_attr(node.args[4].target, gm)
|
| 478 |
+
if bn_running_mean is None or bn_running_var is None:
|
| 479 |
+
continue
|
| 480 |
+
|
| 481 |
+
# Do hash based on the module name of conv
|
| 482 |
+
hash_id = hash(node.args[0].target)
|
| 483 |
+
if hash_id not in conv_bn_to_fuse:
|
| 484 |
+
conv_bn_to_fuse[hash_id] = ConvBNFusion(
|
| 485 |
+
node,
|
| 486 |
+
conv,
|
| 487 |
+
bn_running_mean=bn_running_mean,
|
| 488 |
+
bn_running_var=bn_running_var,
|
| 489 |
+
bn_eps=bn_eps,
|
| 490 |
+
bn_weight=bn_weight,
|
| 491 |
+
bn_bias=bn_bias,
|
| 492 |
+
)
|
| 493 |
+
else:
|
| 494 |
+
if (
|
| 495 |
+
hash(bn_running_mean)
|
| 496 |
+
== hash(conv_bn_to_fuse[hash_id].bn_running_mean)
|
| 497 |
+
and hash(bn_running_var)
|
| 498 |
+
== hash(conv_bn_to_fuse[hash_id].bn_running_var)
|
| 499 |
+
and torch.allclose(
|
| 500 |
+
torch.tensor(bn_eps),
|
| 501 |
+
torch.tensor(conv_bn_to_fuse[hash_id].bn_eps),
|
| 502 |
+
)
|
| 503 |
+
and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight)
|
| 504 |
+
and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias)
|
| 505 |
+
):
|
| 506 |
+
# Do fusion if same functional bn
|
| 507 |
+
conv_bn_to_fuse[hash_id].add_bn_node(node)
|
| 508 |
+
else:
|
| 509 |
+
# Disable the conv bn folding if conv shared by different bn
|
| 510 |
+
conv_bn_to_fuse[hash_id].disable_fusion()
|
| 511 |
+
|
| 512 |
+
for conv_bn_fusion in conv_bn_to_fuse.values():
|
| 513 |
+
if conv_bn_fusion.is_fusion_enabled():
|
| 514 |
+
bn_nodes = conv_bn_fusion.bn_nodes
|
| 515 |
+
conv = conv_bn_fusion.conv_module
|
| 516 |
+
bn_running_mean = conv_bn_fusion.bn_running_mean
|
| 517 |
+
bn_running_var = conv_bn_fusion.bn_running_var
|
| 518 |
+
bn_eps = conv_bn_fusion.bn_eps
|
| 519 |
+
bn_weight = conv_bn_fusion.bn_weight
|
| 520 |
+
bn_bias = conv_bn_fusion.bn_bias
|
| 521 |
+
|
| 522 |
+
fused_conv = copy.deepcopy(conv)
|
| 523 |
+
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
|
| 524 |
+
fused_conv.weight,
|
| 525 |
+
fused_conv.bias,
|
| 526 |
+
bn_running_mean,
|
| 527 |
+
bn_running_var,
|
| 528 |
+
bn_eps,
|
| 529 |
+
bn_weight,
|
| 530 |
+
bn_bias,
|
| 531 |
+
)
|
| 532 |
+
for bn_node in bn_nodes:
|
| 533 |
+
replace_node_module(bn_node.args[0], modules, fused_conv)
|
| 534 |
+
bn_node.replace_all_uses_with(bn_node.args[0])
|
| 535 |
+
gm.graph.erase_node(bn_node)
|
| 536 |
+
gm.graph.lint()
|
| 537 |
+
gm.recompile()
|
| 538 |
+
|
| 539 |
+
return gm
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class NormalizedLinearNode:
    """Uniform accessor for an ``F.linear`` ``call_function`` FX node.

    Callers may pass ``input``/``weight``/``bias`` positionally or by
    keyword; this wrapper hides that difference.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        """Return the activation argument of the linear call."""
        args = self.node.args
        return args[0] if args else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        """Return the weight argument of the linear call."""
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        """Return the bias argument, or None when the call omitted it."""
        args = self.node.args
        if len(args) > 2:
            return args[2]  # type: ignore[return-value]
        # ``bias`` is optional at the call site: missing keyword means None.
        return self.node.kwargs.get("bias")  # type: ignore[return-value]
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class NormalizedMatmulNode:
    """Uniform accessor for a ``torch.bmm``/``torch.matmul`` FX node.

    Hides whether the two operands were passed positionally or via the
    ``input``/``other`` keywords.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        """Return the first (left-hand) matmul operand."""
        args = self.node.args
        return args[0] if args else self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        """Return the second (right-hand) matmul operand."""
        args = self.node.args
        return args[1] if len(args) > 1 else self.node.kwargs["other"]  # type: ignore[return-value]
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def check_permute(node: torch.fx.Node) -> bool:
    """Return True iff *node* is a permute that only swaps the last two dims.

    The permutation is read either from the positional dim arguments
    (``node.args[1:]``) or from a ``permutation`` keyword; negative dims are
    normalized modulo the tensor rank taken from ``meta["tensor_meta"]``.
    """
    ranks = len(node.meta["tensor_meta"].shape)
    kw_perm = node.kwargs.get("permutation")
    if len(node.args) > 3:
        # Positional form: args = (tensor, dim0, dim1, ...).
        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]  # type: ignore[operator]
    elif kw_perm is not None and len(kw_perm) > 2:  # type: ignore[arg-type]
        permutation = [dim % ranks for dim in kw_perm]  # type: ignore[union-attr]
    else:
        return False
    # Identity permutation with the final two axes exchanged.
    expected = list(range(ranks))
    expected[-1] = ranks - 2
    expected[-2] = ranks - 1
    return permutation == expected
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Push a pointwise unary op that follows ``torch.cat`` before the cat.

    Pattern handled: ``cat -> (zero or more single-use ``view`` calls) ->
    relu/tanh``.  The unary op is applied to each cat input instead, and the
    unary node's users are rewired to the end of the (cat + views) chain.
    Mutates ``module`` in place and returns it.
    """

    def one_user(node):
        # Return the sole user of ``node``, or None if it has 0 or >1 users.
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        # Only these targets are recognized as sinkable pointwise unary ops.
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        # Walk forward through a chain of single-use view calls; after the
        # loop, ``cat_or_view`` is the last node of the chain and ``user`` is
        # its (single) consumer, if any.
        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    # Normalize cat's positional/keyword call into (tensors, dim).
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                # Keep the unary op's extra kwargs, minus its tensor input.
                new_kwargs = {
                    name: val for name, val in user.kwargs.items() if name != "input"
                }
                # Apply the unary op to every cat input individually.
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                # Consumers of the unary op now read the view-chain output;
                # the old cat is replaced by the new cat over unary results.
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse ``linear(...) .permute(last-two-dims-swap)`` into ``linear_transpose``.

    For every ``permute`` method call whose dims only swap the last two axes
    (see ``check_permute``) and whose input is an ``F.linear`` call, replace
    the pair with a single ``linear_transpose`` call.  Mutates ``module`` in
    place and returns it.
    """
    for node in module.graph.find_nodes(op="call_method", target="permute"):
        if check_permute(node):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                # Pull (input, weight, bias) out of the linear call regardless
                # of positional/keyword spelling.
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    # The linear node may still have other users; only erase
                    # it when it became dead.
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# Y1 = X * W^T + bias
|
| 685 |
+
# Y2 = Y1.permute(0, 2, 1)
|
| 686 |
+
# ---->
|
| 687 |
+
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
|
| 688 |
+
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fused form of ``F.linear(input, weight, bias).transpose(-1, -2)``.

    Computes ``W @ X^T`` directly (plus a broadcast bias column when bias is
    given) instead of materializing the linear output and transposing it.
    """
    product = torch.matmul(weight, input.transpose(-1, -2))
    return product if bias is None else product + bias.unsqueeze(-1)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse ``x.permute(last-two-dims-swap)`` feeding ``F.linear`` into ``transpose_linear``.

    For every ``F.linear`` call whose input comes from a permute that only
    swaps the last two axes (see ``check_permute``), replace the pair with a
    single ``transpose_linear`` call.  Mutates ``module`` in place and
    returns it.
    """
    for node in module.graph.find_nodes(
        op="call_function", target=torch.nn.functional.linear
    ):
        if len(node.args) > 0:
            input_node = node.args[0]
        else:
            input_node = node.kwargs["input"]
        if (
            input_node.op == "call_method"
            and input_node.target == "permute"
            and check_permute(input_node)
        ):
            normalized = NormalizedLinearNode(node)
            # The fused op consumes the permute's *input*, skipping the
            # permute node itself.
            if len(input_node.args) > 0:
                input = input_node.args[0]
            else:
                input = input_node.kwargs["input"]
            weight = normalized.get_weight()
            bias = normalized.get_bias()
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_linear, args=(input, weight, bias)
                )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                # The permute may have other users; only erase when dead.
                if len(input_node.users) == 0:
                    module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Fuse last-two-dims permutes feeding ``bmm``/``matmul`` into ``transpose_matmul``.

    For each ``torch.bmm`` / ``torch.matmul`` node, if either operand comes
    from a permute that only swaps the last two axes (see ``check_permute``),
    replace the subgraph with a single ``transpose_matmul`` call carrying
    ``Atrans``/``Btrans`` flags.  Mutates ``module`` in place and returns it.
    """
    for node in itertools.chain(
        module.graph.find_nodes(op="call_function", target=torch.bmm),
        module.graph.find_nodes(op="call_function", target=torch.matmul),
    ):
        normalized = NormalizedMatmulNode(node)
        input_A_node = normalized.get_input()
        input_B_node = normalized.get_other()
        # By default the fused op consumes the original operands unchanged.
        input_A = input_A_node
        input_B = input_B_node
        Atrans = Btrans = False
        if (
            input_A_node.op == "call_method"
            and input_A_node.target == "permute"
            and check_permute(input_A_node)
        ):
            Atrans = True
            # Consume the permute's input directly and record the transpose.
            if len(input_A_node.args) > 0:
                input_A = input_A_node.args[0]  # type: ignore[assignment]
            else:
                input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

        if (
            input_B_node.op == "call_method"
            and input_B_node.target == "permute"
            and check_permute(input_B_node)
        ):
            Btrans = True
            if len(input_B_node.args) > 0:
                input_B = input_B_node.args[0]  # type: ignore[assignment]
            else:
                input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

        if Atrans or Btrans:
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_matmul,
                    args=(input_A, input_B, Atrans, Btrans),
                )
            node.replace_all_uses_with(fused_node)
            module.graph.erase_node(node)
            # Each permute may have other users; only erase when dead.
            if Atrans and len(input_A_node.users) == 0:
                module.graph.erase_node(input_A_node)
            if Btrans and len(input_B_node.users) == 0:
                module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
# X1 = X.permute(0, 2, 1)
|
| 782 |
+
# Y1 = X1 * W1^T + bias1
|
| 783 |
+
# ---->
|
| 784 |
+
# Y2 = X1.transpose(-1, -2) * W1^T + bias1
|
| 785 |
+
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fused form of ``F.linear(input.transpose(-1, -2), weight, bias)``.

    Transposes the input's last two dims and multiplies by ``weight.t()``,
    adding ``bias`` when provided.
    """
    out = torch.matmul(input.transpose(-1, -2), weight.t())
    if bias is not None:
        out = out + bias
    return out
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    """Matmul of ``A`` and ``B``, transposing each operand's last two dims on demand.

    ``Atrans``/``Btrans`` select whether the corresponding operand is
    transposed before the product.
    """
    lhs = A.transpose(-1, -2) if Atrans else A
    rhs = B.transpose(-1, -2) if Btrans else B
    return torch.matmul(lhs, rhs)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py
ADDED
|
@@ -0,0 +1,2589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import copy
|
| 4 |
+
import functools
|
| 5 |
+
import itertools
|
| 6 |
+
import math
|
| 7 |
+
import operator
|
| 8 |
+
from typing import Any, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch._dynamo.utils import counters
|
| 12 |
+
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
| 13 |
+
from torch.fx.node import map_arg
|
| 14 |
+
|
| 15 |
+
from ..lowering import lowerings as L, require_channels_last
|
| 16 |
+
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
|
| 17 |
+
from ..utils import pad_listlike
|
| 18 |
+
from .freezing_patterns import register_freezing_graph_pattern
|
| 19 |
+
from .post_grad import register_lowering_pattern
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
aten = torch.ops.aten
|
| 23 |
+
prims = torch.ops.prims
|
| 24 |
+
quantized_decomposed = torch.ops.quantized_decomposed
|
| 25 |
+
quantized = torch.ops.quantized
|
| 26 |
+
|
| 27 |
+
# Only for per tensor quant since permute may changes the channel idx
|
| 28 |
+
_PER_TENSOR_QUANTIZE_OPS = [
|
| 29 |
+
quantized_decomposed.quantize_per_tensor.default,
|
| 30 |
+
quantized_decomposed.quantize_per_tensor.tensor,
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
_VIEW_OPS = [
|
| 34 |
+
aten.transpose.int,
|
| 35 |
+
aten.permute.default,
|
| 36 |
+
aten.view.default,
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
"""
|
| 40 |
+
The quantization.py file primarily incorporates passes related to quantization fusion
|
| 41 |
+
in inductor, includes:
|
| 42 |
+
1. Dequant Promotion;
|
| 43 |
+
2. Conv/GEMM weight prepack with oneDNN Library;
|
| 44 |
+
3. Conv/GEMM quantization fusion with output quant node (if have);
|
| 45 |
+
4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
|
| 46 |
+
|
| 47 |
+
It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
|
| 48 |
+
of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
|
| 49 |
+
1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
|
| 50 |
+
2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
|
| 51 |
+
Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
|
| 52 |
+
quantization.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _get_pattern_output_dtype(match: Match):
|
| 57 |
+
"""
|
| 58 |
+
Get the pattern's output dtype from node's meta
|
| 59 |
+
Assume only 1 output node in this matched pattern.
|
| 60 |
+
"""
|
| 61 |
+
pattern_output_nodes = match.output_nodes()
|
| 62 |
+
assert len(pattern_output_nodes) == 1
|
| 63 |
+
output_node = pattern_output_nodes[0]
|
| 64 |
+
assert isinstance(output_node, torch.fx.Node)
|
| 65 |
+
output_dtype = output_node.meta["val"].dtype
|
| 66 |
+
assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
|
| 67 |
+
return output_dtype
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _may_generate_pattern_with_dtype_convert(
|
| 71 |
+
pattern, dtype=Arg(), with_dtype_convert=True, users=1
|
| 72 |
+
):
|
| 73 |
+
if with_dtype_convert:
|
| 74 |
+
return CallFunction(
|
| 75 |
+
prims.convert_element_type.default,
|
| 76 |
+
pattern,
|
| 77 |
+
dtype,
|
| 78 |
+
_users=users,
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
return pattern
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
|
| 85 |
+
if with_reshape:
|
| 86 |
+
return CallFunction(
|
| 87 |
+
torch.ops.aten.reshape.default,
|
| 88 |
+
pattern,
|
| 89 |
+
reshape_size,
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
return pattern
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _generate_linear_t_pattern(
    _dequant_per_channel_pattern,
    dtype,
):
    """Build the weight-permute (transpose) pattern that feeds a linear op.

    For bf16, a convert_element_type node is inserted between the
    per-channel dequant and the permute.
    """
    assert dtype in [torch.float32, torch.bfloat16]
    weight_pattern = _may_generate_pattern_with_dtype_convert(
        _dequant_per_channel_pattern,
        KeywordArg("autocast_wgt_dtype"),
        dtype == torch.bfloat16,
    )
    return CallFunction(
        aten.permute.default,
        weight_pattern,
        KeywordArg("permute_axes"),
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
    """Apply *unary_fusion* on top of *call_fn*.

    A to_dtype (convert_element_type) node is inserted between them only
    when is_bf16 is True.
    """
    inner = _may_generate_pattern_with_dtype_convert(
        call_fn,
        dtype=KeywordArg("to_float"),
        with_dtype_convert=is_bf16,
        users=users,
    )
    return unary_fusion(inner)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
    """Return the pattern matching a per-tensor dequantize of the activation.

    is_tensor_overload selects the .tensor overload (scale/zero-point passed
    as tensors) instead of the .default overload.
    """
    if is_tensor_overload:
        dequant_op = quantized_decomposed.dequantize_per_tensor.tensor
    else:
        dequant_op = quantized_decomposed.dequantize_per_tensor.default
    return CallFunction(
        dequant_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("x_quant_min"),
        KeywordArg("x_quant_max"),
        KeywordArg("x_dq_dtype"),
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Pattern: per-channel dequantize of the quantized weight.
dequantize_per_channel_weight_pattern = CallFunction(
    quantized_decomposed.dequantize_per_channel.default,
    KeywordArg("q_weight"),
    KeywordArg("w_scale"),
    KeywordArg("w_zp"),
    KeywordArg("w_axis"),
    KeywordArg("w_quant_min"),
    KeywordArg("w_quant_max"),
    KeywordArg("w_dtype"),
)

# Same pattern with a trailing convert_element_type (int8-mixed-bf16 case).
dequantize_per_channel_to_bf16_weight_pattern = (
    _may_generate_pattern_with_dtype_convert(
        dequantize_per_channel_weight_pattern,
        KeywordArg("autocast_wgt_dtype"),
    )
)

# Per-channel weight dequant followed by a clone with an explicit
# memory_format.
dequantize_per_channel_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)

# bf16 variant of the clone pattern above.
dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_to_bf16_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_dequantize_qconv_pt2e_pattern(users=1):
    """Return the pattern for an onednn qconv2d_pointwise call as produced
    by the weight-prepack phase (no post-op fused yet).

    users: expected number of users of the qconv node in the graph.
    """
    return CallFunction(
        torch.ops.onednn.qconv2d_pointwise.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),  # x_scale
        KeywordArg("x_zp"),  # x_zp
        KeywordArg("packed_weight"),  # packed_weight
        KeywordArg("w_scale"),  # w_scale
        KeywordArg("w_zp"),  # w_zp
        KeywordArg("b"),  # bias
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("groups"),
        KeywordArg("output_scale"),  # output_scale = 1.0
        KeywordArg("output_zero_point"),  # output_zero_point = 0
        KeywordArg("output_dtype"),  # output_dtype = None
        KeywordArg("attr"),  # attr = "none"
        Arg(),  # scalars
        Arg(),  # algorithm
        _users=users,
    )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
    """Return the pattern for an onednn qlinear_pointwise call.

    x_scale_zp_are_tensors: select the .tensor overload (activation
        scale/zero-point passed as tensors) over the .default overload.
    users: expected number of users of the qlinear node in the graph.
    """
    qlinear_op = (
        torch.ops.onednn.qlinear_pointwise.tensor
        if x_scale_zp_are_tensors
        else torch.ops.onednn.qlinear_pointwise.default
    )
    return CallFunction(
        qlinear_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("packed_weight"),
        KeywordArg("w_scale"),
        KeywordArg("w_zp"),
        KeywordArg("b"),
        KeywordArg("output_scale"),
        KeywordArg("output_zero_point"),
        KeywordArg("output_dtype"),
        KeywordArg("postop_name"),
        KeywordArg("postop_args"),
        KeywordArg("postop_algorithm"),
        _users=users,
    )
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Pattern: per-tensor dequantize of the extra (accumulated) input feeding a
# binary fusion. The two Arg() slots are matched positionally but not
# captured — presumably quant_min/quant_max, matching the overload's
# signature; verify against the op schema.
dequantize_accum_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    KeywordArg("accum"),
    KeywordArg("accum_scale"),
    KeywordArg("accum_zp"),
    Arg(),
    Arg(),
    KeywordArg("accum_dq_dtype"),
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def generate_pattern_with_binary(
    binary_post_op,
    computation_call,
    extra_input_pattern,
    dtype_convert=False,
    swap_inputs=False,
):
    """Attach a binary post-op to the computation pattern.

    swap_inputs puts the extra input first; dtype_convert additionally
    wraps the result in a convert_element_type node.
    """
    if swap_inputs:
        lhs, rhs = extra_input_pattern, computation_call
    else:
        lhs, rhs = computation_call, extra_input_pattern
    binary_pattern = CallFunction(binary_post_op, lhs, rhs)
    return _may_generate_pattern_with_dtype_convert(
        binary_pattern,
        KeywordArg("convert_dtype_after_inplace_add"),
        dtype_convert,
    )
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def generate_pattern_with_unary(computation_call, unary_post_op):
    """Attach *unary_post_op* to the computation pattern; no-op when None."""
    if unary_post_op is None:
        return computation_call
    return CallFunction(unary_post_op, computation_call)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
    """Wrap the computation pattern with a per-tensor quantize on the output.

    with_dtype_convert inserts a convert_element_type between the
    computation and the quantize node (int8-mixed-bf16 case).
    """
    inner = _may_generate_pattern_with_dtype_convert(
        computation_call,
        Arg(),
        with_dtype_convert,
    )
    return CallFunction(
        quantized_decomposed.quantize_per_tensor.default,
        inner,
        KeywordArg("o_inv_scale"),
        KeywordArg("o_zp"),
        KeywordArg("o_qmin"),
        KeywordArg("o_qmax"),
        KeywordArg("o_dtype"),
    )
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
|
| 280 |
+
if kwarg_name in check_node.kwargs:
|
| 281 |
+
actual_value = check_node.kwargs[kwarg_name]
|
| 282 |
+
return actual_value == expected_value
|
| 283 |
+
else:
|
| 284 |
+
assert len(check_node.args) >= (args_index + 1)
|
| 285 |
+
actual_value = check_node.args[args_index]
|
| 286 |
+
return actual_value == expected_value
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _is_valid_quantized_conv2d_optimization_pattern():
    """Extra-check factory for qconv2d unary fusion matches."""

    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype not in [torch.float32, torch.bfloat16]:
            return True
        # Only keep a matched pattern whose qconv node was created with the
        # same output_dtype (kwarg "output_dtype", positional index 13).
        qconv_node_after_weight_prepack = filter_nodes(
            match.nodes, torch.ops.onednn.qconv2d_pointwise
        )[0]
        return _check_node_kwarg_arg_value(
            qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
        )

    return fn
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _register_quantized_conv_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    """Register a lowering for a matched qconv (+ optional unary / quant)
    pattern into a single fused *computation_op* extern call.

    Args:
        pattern: graph pattern to match.
        pass_number: pattern-matcher pass in which to register.
        computation_op: fused op to lower to (onednn qconv2d_pointwise).
        unary_attr: fused unary post-op description (op name, scalars,
            algorithm).
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv(match: Match, *args, **kwargs):
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # Conv Params
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
        # Output QParams: only meaningful when the pattern ends in a quant
        # node (uint8 output); otherwise identity scale/zero-point.
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
        assert (
            kwargs["attr"] == "none"
        )  # Expected no post op fused in weight prepack phase
        # hardtanh carries its clamp bounds as scalars on the fused op.
        if unary_attr.op_name == "hardtanh":
            min_value = kwargs.get("min_value")
            max_value = kwargs.get("max_value")
            unary_attr.scalars_attr = [min_value, max_value]

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_unary_matcher_count"] += 1
        counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _is_valid_quantized_linear_optimization_pattern():
    """Extra-check factory for qlinear unary fusion matches."""

    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype not in [torch.float32, torch.bfloat16]:
            return True
        # Only keep a matched pattern whose qlinear node was created with
        # the same output_dtype (kwarg "output_dtype", positional index 9).
        qlinear_node_after_weight_prepack = filter_nodes(
            match.nodes, torch.ops.onednn.qlinear_pointwise
        )[0]
        return _check_node_kwarg_arg_value(
            qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
        )

    return fn
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _register_quantized_linear_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    """Register a lowering for a matched qlinear (+ optional unary / quant)
    pattern into a single fused *computation_op* extern call.

    Args:
        pattern: graph pattern to match.
        pass_number: pattern-matcher pass in which to register.
        computation_op: fused op to lower to (onednn qlinear_pointwise).
        unary_attr: fused unary post-op description (op name, scalars,
            algorithm).
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_linear_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # bias (optional)
        b = kwargs["b"] if "b" in kwargs else None

        # Output QParams: identity scale/zero-point unless the pattern ends
        # in a quant node (uint8 output).
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
        assert (
            kwargs["postop_name"] == "none"
        )  # Expected no post op fused in weight prepack phase

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_unary_matcher_count"] += 1
        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def _register_quantized_linear_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    """Register a lowering for a matched qlinear + binary (+ unary) pattern
    into a single fused *computation_op* extern call.

    Args:
        pattern: graph pattern to match.
        pass_number: pattern-matcher pass in which to register.
        computation_op: fused binary op to lower to.
        binary_unary_attr: description of the fused binary op (name, alpha)
            and trailing unary op (name, scalars, algorithm).
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qlinear_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Extra input: "sum" consumes the quantized accumulator, any other
        # binary op consumes the plain "other" input.
        x2 = (
            kwargs["accum"]
            if binary_unary_attr.binary_op_name == "sum"
            else kwargs["other"]
        )
        x2_scale = 1.0
        x2_zp = 0
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # bias (optional)
        b = kwargs["b"] if "b" in kwargs else None
        # Output QParams: identity scale/zero-point unless uint8 output.
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        x2.realize()
        from .mkldnn_fusion import _can_be_inplace

        binary_op_name = binary_unary_attr.binary_op_name

        if binary_op_name == "sum" and not _can_be_inplace(x2):
            # When we enable the GEMM Template, the output of QLinear
            # will be reshaped from 2D back to 3D if the input is 3D.
            # This causes _can_be_inplace(x2) to return False if x2 happens
            # to be the output of QLinear in this scenario.
            # Change the post op from sum to binary add for this case.
            # Refer to test case:
            # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
            binary_op_name = "add"

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            x2,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_binary_matcher_count"] += 1
        counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear_binary
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def _is_valid_qconv_binary_optimization_pattern():
    """Extra-check factory for qconv2d binary fusion matches."""
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qconv2d_pointwise
    )
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def _is_valid_qlinear_binary_optimization_pattern():
    """Extra-check factory for qlinear binary fusion matches."""
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qlinear_pointwise,
        # we don't insert q-dq for extra input due to accuracy issues
        extra_input_from_dequant=False,
    )
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def _is_valid_quantized_op_binary_optimization_pattern(
    qop, extra_input_from_dequant=True
):
    """Extra-check factory validating a binary fusion pattern for *qop*.

    Check if it's a valid Binary Pattern for qconv2d and qlinear:
    # * qop_pointwise should only has one users
    # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern
    # * the two inputs of binary node should have attribute "meta" and should be tensors
    # * the two inputs of binary node should have the same shape
    # * All users of the extra input in this pattern should be
    #   ancestor nodes of the compute node, except for the binary node
    #   connected to the compute node.
    """

    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        compute_node = filter_nodes(match.nodes, qop)[0]
        # qop_pointwise should only have one user
        if len(compute_node.users) != 1:
            return False
        binary_node_inputs = next(iter(compute_node.users)).args
        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
        if output_dtype in [torch.float32, torch.bfloat16]:
            extra_input_of_binary_node = None
            for arg in binary_node_inputs:
                if arg != compute_node:
                    extra_input_of_binary_node = arg
                    break
            assert extra_input_of_binary_node is not None
            # Extra input of binary node comes from dequant pattern
            if extra_input_from_dequant and (
                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
                or (
                    extra_input_of_binary_node.target
                    != quantized_decomposed.dequantize_per_tensor.default
                )
            ):
                return False

        # the two inputs of binary node should have attribute "meta" and should be tensors
        if not (
            hasattr(binary_node_inputs[0], "meta")
            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ) or not (
            hasattr(binary_node_inputs[1], "meta")
            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ):
            return False
        # the two inputs of binary node should have the same shape
        if (
            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
        ):
            return False

        # All users of the extra input in this pattern should be
        # ancestor nodes of the compute node, except for the binary node
        # connected to the compute node.

        from .mkldnn_fusion import _get_remaining_users

        # The kwarg holding the extra input depends on which pattern
        # variant matched ("other", "accum", or "accum_after_dequant").
        extra_input_of_pattern = (
            match.kwargs["other"]
            if "other" in match.kwargs
            else (
                match.kwargs["accum"]
                if output_dtype == torch.uint8 or (not extra_input_from_dequant)
                else match.kwargs["accum_after_dequant"]
            )
        )
        if (
            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
            or extra_input_of_pattern == compute_node.args[0]
        ):
            return False
        return True

    return fn
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def _register_quantized_conv_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    """Register a lowering for a matched qconv + binary (+ unary) pattern
    into a single fused *computation_op* extern call.

    Args:
        pattern: graph pattern to match.
        pass_number: pattern-matcher pass in which to register.
        computation_op: fused binary op to lower to.
        binary_unary_attr: description of the fused binary op (name, alpha)
            and trailing unary op (name, scalars, algorithm).
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qconv_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
        # The accumulated input: quantized for uint8 output, otherwise the
        # already-dequantized variant captured by the pattern.
        accum = (
            kwargs["accum"]
            if output_dtype == torch.uint8
            else kwargs["accum_after_dequant"]
        )
        accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
        accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        accum.realize()
        from .mkldnn_fusion import _can_be_inplace

        assert _can_be_inplace(
            accum
        ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."

        computation_args = (
            x,
            x_scale,
            x_zp,
            accum,
            accum_scale,
            accum_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            binary_unary_attr.binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_binary_matcher_count"] += 1
        counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv_binary
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def _register_quantization_unary_fusion():
    """Register all qconv2d / qlinear unary fusion lowerings.

    For each output dtype (fp32, bf16), two priorities are registered:
    pass_number 1 matches patterns ending in a quant node (int8 output),
    pass_number 2 matches the same patterns without the trailing quant
    (fp32/bf16 output). Larger patterns must be tried first so they are not
    shadowed by their sub-patterns.
    """
    from .mkldnn_fusion import (
        _gelu_fusion_1 as _gelu_fusion_erf,
        _gelu_fusion_2 as _gelu_fusion_tanh,
        _hardswish_fusion,
        _hardtanh_fusion,
        _silu_fusion,
    )

    class UnaryAttr:
        # Plain holder for the fused unary post-op description.
        def __init__(
            self, op_name: str, scalars_attr=None, algorithm_attr=None
        ) -> None:
            self.op_name = op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
        # QConv2d
        # Priority 1 to match: QConv2d Unary pattern with int8 output
        # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
        is_bf16 = original_pattern_output_dtype == torch.bfloat16
        conv_unary_replace_patterns = {
            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                get_dequantize_qconv_pt2e_pattern(1),
            ),
            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
                ),
            ),
            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                1,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
        conv_unary_replace_float_out_patterns = {
            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
            ),
            UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # QLinear: both overloads (scale/zp scalar or tensor) are covered.
        for x_scale_zp_are_tensors in (False, True):
            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
            # Priority 1 to match: QLinear Unary pattern with int8 output
            linear_unary_replace_patterns = {
                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                    qlinear_pattern,
                ),
                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
                ),
                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )

            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
            linear_unary_replace_float_out_patterns = {
                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                    qlinear_pattern, aten.relu.default
                ),
                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    2,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def _register_quantization_binary_fusion():
    """Register lowering patterns that fuse quantized conv2d/linear with a
    binary add (and optional unary post-op) into single oneDNN binary kernels.

    Patterns are registered in descending match priority (pass_number 0..2):
    fusions that also absorb the output quant and/or a unary post-op are tried
    before plain binary fusions with fp32/bf16 output.
    """

    # Bundles the binary post-op name with an optional follow-up unary op and
    # its scalar/algorithm attributes, as consumed by the lowering functions.
    class BinaryUnaryAttr:
        def __init__(
            self,
            binary_op_name: str,
            alpha=None,
            unary_op_name: str = "none",
            scalars_attr=None,
            algorithm_attr=None,
        ) -> None:
            self.binary_op_name = binary_op_name
            self.alpha = alpha if alpha else 1.0
            self.unary_op_name = unary_op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for int8_mixed_bf16_with_inplace_add in [False, True]:
        # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
        binary_replace_patterns = {
            BinaryUnaryAttr(
                "sum", 1.0, "none", [], ""
            ): generate_pattern_with_output_quant(
                generate_pattern_with_binary(
                    aten.add.Tensor,
                    get_dequantize_qconv_pt2e_pattern(1),
                    dequantize_accum_pattern,
                    int8_mixed_bf16_with_inplace_add,
                ),
            ),
            BinaryUnaryAttr(
                "sum", 1.0, "relu", [], ""
            ): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_dequantize_qconv_pt2e_pattern(1),
                        dequantize_accum_pattern,
                        int8_mixed_bf16_with_inplace_add,
                    ),
                    aten.relu.default,
                ),
            ),
        }

        for binary_unary_attr, patterns in binary_replace_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                0,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

        # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {
            BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                generate_pattern_with_binary(
                    aten.add.Tensor,
                    get_dequantize_qconv_pt2e_pattern(1),
                    KeywordArg("accum_after_dequant"),
                    int8_mixed_bf16_with_inplace_add,
                ),
                aten.relu.default,
            ),
        }

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            if int8_mixed_bf16_with_inplace_add:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    0,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    binary_unary_attr,  # binary_unary_attr
                )
            else:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    binary_unary_attr,  # binary_unary_attr
                )

        # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {
            BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary(
                aten.add.Tensor,
                get_dequantize_qconv_pt2e_pattern(1),
                KeywordArg("accum_after_dequant"),
                int8_mixed_bf16_with_inplace_add,
            ),
        }

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                1 if int8_mixed_bf16_with_inplace_add else 2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

    # QLinear
    r"""
    Supported linear-binary(-unary) patterns

        linear(X)   extra input
               \   /
                Add
                 |
        Optional(relu)
                 |
                 Y

    1. int8-mixed-fp32
    +---+---------------+-----------+------------------------------+---------+
    | # | Add type      | Quant out | Pattern                      | Post op |
    +---+---------------+-----------+------------------------------+---------+
    | 1 | In-/out-place | Yes       | linear + fp32 -> (relu) -> q | add     |
    +---+---------------+-----------+------------------------------+---------+
    | 2 | In-/out-place | No        | linear + fp32 -> (relu)      | sum     |
    +---+---------------+-----------+------------------------------+---------+

    2. int8-mixed-bf16
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | # | X2 dtype | Add type      | Quant out | Pattern                                 | Post op |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 1 | BF16     | In-/out-place | Yes       | linear + bf16 -> (relu) -> q            | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 2 | BF16     | In-/out-place | No        | linear + bf16 -> (relu)                 | sum     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 3 | FP32     | Out-place     | Yes       | linear + fp32 -> (relu) -> q            | add     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 4 | FP32     | Out-place     | No        | linear + fp32 -> (relu)                 | sum     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 5 | FP32     | In-place left | Yes       | linear + fp32 -> to_bf16 -> (relu) -> q | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 6 | FP32     | In-place left | No        | linear + fp32 -> to_bf16 -> (relu)      | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+

    Note
    (1) The positions of linear and the extra input can be swapped.
    (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the
    extra input, we don't match that pattern because we cannot match all these patterns in 3 passes.
    """
    for x_scale_zp_are_tensors in (False, True):
        qlinear_binary_op = (
            torch.ops.onednn.qlinear_pointwise.binary_tensor
            if x_scale_zp_are_tensors
            else torch.ops.onednn.qlinear_pointwise.binary
        )
        unary_postop_list = ["none", "relu"]
        unary_postop_dict = {
            "none": None,
            "relu": aten.relu.default,
        }
        convert_dtype_after_binary_list = [False, True]

        # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
        # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
        # totally 3 patterns (2 are identical)
        swap_binary_inputs_list = [False, True]
        int8_mixed_bf16_list = [False, True]
        combinations = itertools.product(
            unary_postop_list,
            int8_mixed_bf16_list,
            swap_binary_inputs_list,
            convert_dtype_after_binary_list,
        )
        qlinear_binary_replace_patterns = {}
        for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
            if not int8_mixed_bf16 and cvt_dtype_binary:
                # No convert node after binary node if dtypes are all fp32
                continue
            qlinear_binary_replace_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, unary_op, [], ""
                    ): generate_pattern_with_output_quant(
                        generate_pattern_with_unary(
                            generate_pattern_with_binary(
                                aten.add.Tensor,
                                get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                                KeywordArg("other"),
                                # If fp32 extra input is inplace added to bf16 linear output,
                                # a to_bf16 node is inserted after binary
                                dtype_convert=cvt_dtype_binary,
                                swap_inputs=swap_inputs,
                            ),
                            unary_postop_dict[unary_op],
                        ),
                    )
                }
            )
        for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                0,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

        # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
        # totally 2 patterns (2 are identical)
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "relu", [], ""
                    ): generate_pattern_with_unary(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                            KeywordArg("accum"),
                            dtype_convert=False,
                            swap_inputs=swap_binary_inputs,
                        ),
                        aten.relu.default,
                    ),
                }
            )
        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                1,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )
        # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
        # Covers case (6) of int8-mixed-bf16
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, "relu", [], ""
                    ): generate_pattern_with_unary(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                            KeywordArg("other"),
                            dtype_convert=True,
                            swap_inputs=swap_binary_inputs,
                        ),
                        aten.relu.default,
                    ),
                }
            )
        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                1,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )

        # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
        # totally 2 patterns (2 are identical)
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "none", [], ""
                    ): generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                        KeywordArg("accum"),
                        dtype_convert=False,
                        swap_inputs=swap_binary_inputs,
                    ),
                }
            )
        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                2,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )
        # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
        # Covers (6) of int8-mixed-bf16
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, "none", [], ""
                    ): generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                        KeywordArg("other"),
                        dtype_convert=True,
                        swap_inputs=swap_binary_inputs,
                    ),
                }
            )
        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                2,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
def _is_valid_quantized_maxpool2d_optimization_pattern():
    """Return an extra-check callable accepting only matches that consume the
    value output (index 0) of max_pool2d_with_indices, not the indices."""

    def fn(match):
        getitem_node = filter_nodes(match.nodes, operator.getitem)[0]
        return getitem_node.args[1] == 0

    return fn
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
def _register_quantized_maxpool2d_lowering(
    pattern,
    computation_op,
):
    """Register a lowering that replaces a dq -> max_pool2d -> q match with
    the quantized pooling op `computation_op` applied to the uint8 input."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
    )
    def qmaxpool2d(match: Match, *args, **kwargs):
        x = kwargs["x"]
        kernel_size = kwargs["kernel_size"]
        # Optional pooling args may be absent from the matched graph; use the
        # aten defaults when missing.
        stride = kwargs.get("stride")
        padding = kwargs.get("padding", 0)
        dilation = kwargs.get("dilation", 1)
        ceil_mode = kwargs.get("ceil_mode", False)

        # Normalize scalar defaults to 2-element lists, and an unset stride
        # to the kernel size (aten semantics).
        if padding == 0:
            padding = [0, 0]
        if dilation == 1:
            dilation = [1, 1]
        if not stride:
            stride = kernel_size
        kernel_size = pad_listlike(kernel_size, 2)
        stride = pad_listlike(stride, 2)
        padding = pad_listlike(padding, 2)
        dilation = pad_listlike(dilation, 2)

        assert len(kernel_size) == 2
        assert len(stride) == 2
        assert len(padding) == 2
        assert len(dilation) == 2

        pool_args = (x, kernel_size, stride, padding, dilation, ceil_mode)
        # Quantized pooling requires channels-last layout.
        pool_args, _ = require_channels_last(computation_op, *pool_args)
        counters["inductor"]["qmaxpool2d_matcher_count"] += 1
        counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*pool_args)

    return qmaxpool2d
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
def _register_quantization_maxpool2d():
    """Register quantized max_pool2d lowerings for every graph shape that
    nn.MaxPool2d's optional arguments can produce."""
    # Currently, the default parameters are not in FX Graph generated by Dynamo export.
    # So, if user defines nn.MaxPool2d with different assignment of default parameter,
    # it will generate graph with different number of input nodes and hence
    # different pattern to be matched.
    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
    max_pool2d_args_list = [
        [
            KeywordArg("stride"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
            KeywordArg("ceil_mode"),
        ],
    ]
    for max_pool2d_args in max_pool2d_args_list:
        # Standard aten.max_pool2d_with_indices over a dequantized input.
        dequantize_maxpool2d_pattern = CallFunction(
            aten.max_pool2d_with_indices.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
        )
        # Low-memory variant, which takes a trailing offset dtype argument.
        dequantize_lowmem_maxpool2d_pattern = CallFunction(
            prims._low_memory_max_pool2d_with_offsets.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
            KeywordArg("offset_dtype"),
        )
        # Both ops return a tuple; the lowering matches the getitem consumer
        # (the extra check restricts it to index 0, the values output).
        dequantize_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_maxpool2d_pattern,
            Arg(),
        )
        dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_lowmem_maxpool2d_pattern,
            Arg(),
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
            quantized.max_pool2d.default,
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(
                dequantize_lowmem_maxpool2d_get_item_pattern
            ),
            quantized.max_pool2d.default,
        )
|
| 1341 |
+
|
| 1342 |
+
|
| 1343 |
+
def _is_input_output_same_scale_zp(check_node):
    """Return an extra-check callable verifying that all dequantized inputs
    and the quantized output of a match share one scale and one zero point.

    Args:
        check_node: op being checked (e.g. ``aten.cat.default``); currently
            unused by the check itself but kept for the registration API.
    """

    def fn(match):
        # Ensure all the inputs and output has same scale and zero point
        # Get dequant nodes at input and the single quant node at output.
        dequant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.dequantize_per_tensor.default
        )
        quant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.quantize_per_tensor.default
        )
        # Fix: message previously said "add node" although the asserted list
        # contains output quantize nodes.
        assert len(quant_nodes) == 1, "expect only 1 quant node at output quant pattern"

        # Step 1: Check inputs/output zero point (arg index 2).
        zero_points = [node.args[2] for node in dequant_nodes]
        zero_points.append(quant_nodes[0].args[2])
        if not all(zero_point == zero_points[0] for zero_point in zero_points):
            return False

        # Step 2: Check inputs/output scale (arg index 1) within float tolerance.
        scales = [node.args[1] for node in dequant_nodes]
        scales.append(quant_nodes[0].args[1])
        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
            return False

        return True

    return fn
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
def _register_quantized_cat_lowering(
    pattern,
    computation_op,
):
    """Lower a dq -> cat -> q match to `computation_op` on the raw uint8
    inputs, valid because all scales/zero-points are checked to be equal."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
    )
    def qcat(match: Match, inputs, dim, **kwargs):
        # Each entry of `inputs` is [x_i, x_i_dq_dtype, x_i_zp, x_i_scale];
        # only the raw uint8 tensor (element 0) is needed here.
        uint8_inputs = [entry[0] for entry in inputs]
        counters["inductor"]["qcat_matcher_count"] += 1
        counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
        return L[computation_op](uint8_inputs, dim)

    return qcat
|
| 1388 |
+
|
| 1389 |
+
|
| 1390 |
+
# Positional-only dequantize_per_tensor pattern: matches all six arguments
# (input, scale, zero_point, quant_min, quant_max, dtype) without binding
# names, for use inside ListOf (e.g. the quantized cat pattern).
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
)
|
| 1399 |
+
|
| 1400 |
+
|
| 1401 |
+
def _register_quantization_cat():
|
| 1402 |
+
dequantize_cat_pattern = CallFunction(
|
| 1403 |
+
aten.cat.default,
|
| 1404 |
+
ListOf(_raw_dequantize_per_tensor_activation_pattern),
|
| 1405 |
+
KeywordArg("dim"),
|
| 1406 |
+
)
|
| 1407 |
+
_register_quantized_cat_lowering(
|
| 1408 |
+
generate_pattern_with_output_quant(dequantize_cat_pattern),
|
| 1409 |
+
aten.cat,
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
|
| 1413 |
+
def _register_quantized_reshape_lowering(
    pattern,
    computation_op,
):
    """Lower a dq -> reshape -> q match to `computation_op` applied directly
    to the quantized tensor (scales/zero-points verified equal)."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
    )
    def qreshape(match: Match, *args, **kwargs):
        counters["inductor"]["qreshape_matcher_count"] += 1
        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
        return L[computation_op](kwargs["x"], kwargs["shape"])

    return qreshape
|
| 1429 |
+
|
| 1430 |
+
|
| 1431 |
+
def _register_quantization_reshape():
    """Register the quantized reshape lowering: q(reshape(dq(x), shape))."""
    reshape_over_dequant = CallFunction(
        torch.ops.aten.reshape.default,
        get_dequantize_per_tensor_activation_pattern(),
        KeywordArg("shape"),
    )
    _register_quantized_reshape_lowering(
        generate_pattern_with_output_quant(reshape_over_dequant),
        aten.reshape,
    )
|
| 1441 |
+
|
| 1442 |
+
|
| 1443 |
+
def _is_valid_woq_optimization_pattern():
    """Extra check for weight-only-quant matches: bf16 activation, int8
    weight, bf16 scales, all living on the same CPU device."""

    def fn(match):
        assert all(k in match.kwargs for k in ("x", "weight", "scales"))
        x = match.kwargs["x"].meta["val"]
        weight = match.kwargs["weight"].meta["val"]
        scales = match.kwargs["scales"].meta["val"]
        # For now, we only support woq mm kernels
        # with x.type=bfloat16 and w.type=int8
        dtypes_ok = (
            x.dtype == torch.bfloat16
            and weight.dtype == torch.int8
            and scales.dtype == torch.bfloat16
        )
        # _weight_int8pack_mm kernel only supports cpu now
        # TODO: add cuda kernel support instead of calling mul+sum
        devices_ok = (
            x.device.type == "cpu"
            and x.device == weight.device
            and x.device == scales.device
        )
        return dtypes_ok and devices_ok

    return fn
|
| 1463 |
+
|
| 1464 |
+
|
| 1465 |
+
def _register_woq_lowering(pattern, computation_woq, computation_reshape):
    """Lower a matched weight-only-quant linear pattern to
    reshape -> `computation_woq` -> reshape."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_woq_optimization_pattern(),
    )
    def woq(match: Match, *args, **kwargs):
        x = kwargs["x"]
        weight = kwargs["weight"]
        scales = kwargs["scales"]
        counters["inductor"]["woq_matcher_count"] += 1
        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
        # Flatten the leading dims of x into a 2D mm, then restore them with
        # the weight's out_features as the trailing dim.
        out_features = weight.get_size()[0]
        origin_x_size = x.get_size()
        flat_x = L[computation_reshape](x, [-1, origin_x_size[-1]])
        mm_out = L[computation_woq](flat_x, weight, scales)
        return L[computation_reshape](mm_out, origin_x_size[:-1] + [out_features])

    return woq
|
| 1487 |
+
|
| 1488 |
+
|
| 1489 |
+
def _register_woq_mm_int8_pattern1():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, with x reshape
    converted_weight_t = CallFunction(
        aten.permute.default,
        CallFunction(
            prims.convert_element_type.default, KeywordArg("weight"), Arg()
        ),
        Arg(),
    )
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
                converted_weight_t,
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
| 1512 |
+
|
| 1513 |
+
|
| 1514 |
+
def _register_woq_mm_int8_pattern2():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, w/o x reshape
    converted_weight_t = CallFunction(
        aten.permute.default,
        CallFunction(
            prims.convert_element_type.default, KeywordArg("weight"), Arg()
        ),
        Arg(),
    )
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                KeywordArg("x"),
                converted_weight_t,
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
def _register_woq_mm_int8_pattern3():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to bmm
    converted_weight_t = CallFunction(
        aten.permute.default,
        CallFunction(
            prims.convert_element_type.default, KeywordArg("weight"), Arg()
        ),
        Arg(),
    )
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.bmm.default,
            CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
            CallFunction(
                aten.expand.default,
                converted_weight_t,
                Arg(),
            ),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
|
| 1562 |
+
|
| 1563 |
+
|
| 1564 |
+
def _register_quantization_lowerings():
    """Register every quantization lowering: unary/binary fusions plus the
    quantized maxpool2d, cat, and reshape patterns."""
    _register_quantization_unary_fusion()
    _register_quantization_binary_fusion()
    _register_quantization_maxpool2d()
    _register_quantization_cat()
    _register_quantization_reshape()
|
| 1570 |
+
|
| 1571 |
+
|
| 1572 |
+
def _register_woq_lowerings():
    """Register all weight-only-quantization (int8 weight) mm/bmm patterns."""
    _register_woq_mm_int8_pattern1()
    _register_woq_mm_int8_pattern2()
    _register_woq_mm_int8_pattern3()
|
| 1576 |
+
|
| 1577 |
+
|
| 1578 |
+
def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
    """Return an extra-check callable deciding whether a matched dequant chain
    should be duplicated (promoted) per user.

    Promotion applies only when the chain ends in an actual
    dequantize_per_tensor node — possibly wrapped by reshape and/or a to_bf16
    convert, depending on `dtype` — and that end node has more than one user,
    since a shared dequant blocks per-user fusion.
    """

    def _inner(match):
        assert dtype in [torch.float32, torch.bfloat16]
        dequant_pattern_end_node = match.output_node()
        # The pattern output must be one of the recognized chain-ending ops.
        if dequant_pattern_end_node.target not in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]:
            return False

        # Walk back from the output to the underlying dequantize node; the
        # number of hops depends on whether a to_bf16 convert is present.
        if dequant_pattern_end_node.target is aten.reshape.default:
            dequant_node = (
                dequant_pattern_end_node.args[
                    0
                ]  # pattern: linear <- reshape <- dequant
                if dtype == torch.float32
                else dequant_pattern_end_node.args[0].args[
                    0
                ]  # pattern: linear <- reshape <- to_bf16 <- dequant
            )
        else:
            dequant_node = (
                dequant_pattern_end_node  # pattern: linear <- dequant
                if dtype == torch.float32
                else dequant_pattern_end_node.args[
                    0
                ]  # pattern: linear <- to_bf16 <- dequant
            )

        if (
            dequant_node.target
            in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]
            and len(list(dequant_pattern_end_node.users)) > 1
        ):
            # If dequant pattern has more than 1 users, then do dequant promoted
            return True
        return False

    return _inner
|
| 1622 |
+
|
| 1623 |
+
|
| 1624 |
+
def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
    """Register a freezing-graph pass that duplicates a shared dequant chain
    so that each user node gets a private copy and can fuse independently.

    Args:
        pattern: dequant chain pattern to match.
        pass_number: ordering of this pass among freezing graph passes.
        dtype: activation dtype; torch.bfloat16 implies an extra to_bf16 node
            in the chain.
    """

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_promotion_pattern(dtype),
        pass_number=pass_number,
    )
    def dequant_promotion(match: Match, *args, **kwargs):
        # Dequant_promotion will transform
        # graph 1:
        #            quant
        #      + - - - | - - - +
        #      |    dequant    |
        #      |    /     \    |
        #      |  node1  node2 |
        #      + - | - - - | - +
        #        quant   quant
        # into:
        # graph 2:
        #            quant
        #      + - - / - \ - - +
        #      |dequant dequant|
        #      |    |      |   |
        #      | node1   node2 |
        #      + - | - - - | - +
        #        quant   quant
        # In graph 1, the dequant node is shared by node1 and node2,
        # as a result, neither node1 nor node2 could form an int8
        # fusion pattern.
        # After this transformation, the graph 2 could hit the int8
        # fusion pattern: dequant-node-quant, respectively for
        # node1 and node2.
        assert dtype in [torch.float32, torch.bfloat16]

        def clone_to_new_node(graph, source_node, user_node):
            # Clone the source_node to a new node
            # Replace user_node's input from source_node to new_node
            assert (
                source_node.op == "call_function"
            ), "clone_to_new_node only support node.op call_function"
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        # Find the start node and end node of a dequant pattern
        # * End node should be the match.output_node()
        # * Start node should be the node of dequantize_per_tensor
        dequant_pattern_end_node = match.output_node()
        assert dequant_pattern_end_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]

        # For a dequant pattern, we should expect see the node list as:
        # * OPT(aten.reshape.default)
        # * OPT(prims.convert_element_type.default) (to_bf16)
        # * dequantize_per_tensor
        def _find_first_node_in_dequant_pattern(_node):
            if _node.target in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]:
                # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
                return _node
            else:
                assert (
                    len(_node.args) >= 1
                ), "In in dequant pattern, each node should have more than 1 arg."
                return _find_first_node_in_dequant_pattern(_node.args[0])

        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
            dequant_pattern_end_node
        )

        assert dequant_pattern_start_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
        ]

        # Clone the dequant pattern for each user node beyond the first; the
        # first user keeps the original chain.
        graph = match.graph
        user_node_list = list(dequant_pattern_end_node.users)
        for user_node in user_node_list[1:]:
            _source_node = dequant_pattern_end_node
            _user_node = user_node
            while _source_node != dequant_pattern_start_node.args[0]:
                _user_node = clone_to_new_node(graph, _source_node, _user_node)
                _source_node = _source_node.args[0]  # type: ignore[assignment]

        counters["inductor"]["dequant_promotion_matcher_count"] += 1
        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
|
| 1722 |
+
|
| 1723 |
+
|
| 1724 |
+
def _is_valid_dequant_conv2d_pattern(dtype):
|
| 1725 |
+
def _inner(match):
|
| 1726 |
+
# Here we do some further check to ensure:
|
| 1727 |
+
# 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
|
| 1728 |
+
# 2. The dequant pattern has only 1 user of conv2d node.
|
| 1729 |
+
# If these conditions don't meet, we will not
|
| 1730 |
+
# insert weight prepack node into the matched pattern.
|
| 1731 |
+
conv_node = match.output_node()
|
| 1732 |
+
assert conv_node.target is aten.convolution.default
|
| 1733 |
+
input_meta_value = conv_node.args[0].meta.get("val")
|
| 1734 |
+
weight_meta_value = conv_node.args[1].meta.get("val")
|
| 1735 |
+
for meta_value in [input_meta_value, weight_meta_value]:
|
| 1736 |
+
if (
|
| 1737 |
+
meta_value is None
|
| 1738 |
+
or meta_value.device.type != "cpu"
|
| 1739 |
+
or meta_value.dim() != 4
|
| 1740 |
+
):
|
| 1741 |
+
# Only support conv2d now
|
| 1742 |
+
return False
|
| 1743 |
+
|
| 1744 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 1745 |
+
|
| 1746 |
+
if dtype == torch.float32:
|
| 1747 |
+
dequant_node = conv_node.args[0]
|
| 1748 |
+
else:
|
| 1749 |
+
convert_to_bf16 = conv_node.args[0]
|
| 1750 |
+
dequant_node = convert_to_bf16.args[0]
|
| 1751 |
+
|
| 1752 |
+
if len(list(dequant_node.users)) != 1:
|
| 1753 |
+
# Ensure the dequant pattern only has 1 user
|
| 1754 |
+
# since we will delete the dequant pattern here
|
| 1755 |
+
return False
|
| 1756 |
+
return True
|
| 1757 |
+
|
| 1758 |
+
return _inner
|
| 1759 |
+
|
| 1760 |
+
|
| 1761 |
+
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
|
| 1762 |
+
@register_freezing_graph_pattern(
|
| 1763 |
+
pattern,
|
| 1764 |
+
extra_check=_is_valid_dequant_conv2d_pattern(dtype),
|
| 1765 |
+
pass_number=pass_number,
|
| 1766 |
+
)
|
| 1767 |
+
def qconv_weight_prepack(match: Match, *args, **kwargs):
|
| 1768 |
+
"""
|
| 1769 |
+
Match the pattern:
|
| 1770 |
+
int8 activation
|
| 1771 |
+
|
|
| 1772 |
+
dequant_per_tensor
|
| 1773 |
+
|
|
| 1774 |
+
Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
|
| 1775 |
+
|
| 1776 |
+
Insert weight prepack node and change the pattern to:
|
| 1777 |
+
int8 activation
|
| 1778 |
+
|
|
| 1779 |
+
onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
|
| 1780 |
+
"""
|
| 1781 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 1782 |
+
conv_node = match.output_node()
|
| 1783 |
+
assert conv_node.target is aten.convolution.default
|
| 1784 |
+
if dtype == torch.float32:
|
| 1785 |
+
dequant_node = conv_node.args[0]
|
| 1786 |
+
else:
|
| 1787 |
+
convert_to_bf16 = conv_node.args[0]
|
| 1788 |
+
dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr]
|
| 1789 |
+
has_clone_to_channel_last_node_in_pattern = (
|
| 1790 |
+
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
|
| 1791 |
+
)
|
| 1792 |
+
clone_node = (
|
| 1793 |
+
conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
|
| 1794 |
+
)
|
| 1795 |
+
|
| 1796 |
+
if dtype == torch.float32:
|
| 1797 |
+
dequant_per_channel = (
|
| 1798 |
+
clone_node.args[0] # type: ignore[union-attr]
|
| 1799 |
+
if has_clone_to_channel_last_node_in_pattern
|
| 1800 |
+
else conv_node.args[1]
|
| 1801 |
+
)
|
| 1802 |
+
else:
|
| 1803 |
+
weight_to_bf16_node = (
|
| 1804 |
+
clone_node.args[0] # type: ignore[union-attr]
|
| 1805 |
+
if has_clone_to_channel_last_node_in_pattern
|
| 1806 |
+
else conv_node.args[1]
|
| 1807 |
+
)
|
| 1808 |
+
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
|
| 1809 |
+
|
| 1810 |
+
assert (
|
| 1811 |
+
dequant_per_channel.target # type: ignore[union-attr]
|
| 1812 |
+
is quantized_decomposed.dequantize_per_channel.default
|
| 1813 |
+
)
|
| 1814 |
+
|
| 1815 |
+
# Activation QParams
|
| 1816 |
+
qx, x_zp, x_scale = (
|
| 1817 |
+
kwargs["x"],
|
| 1818 |
+
kwargs["x_zp"],
|
| 1819 |
+
kwargs["x_scale"],
|
| 1820 |
+
)
|
| 1821 |
+
|
| 1822 |
+
# Weight QParams
|
| 1823 |
+
qw, w_scale, w_zp = (
|
| 1824 |
+
kwargs["q_weight"],
|
| 1825 |
+
kwargs["w_scale"],
|
| 1826 |
+
kwargs["w_zp"],
|
| 1827 |
+
)
|
| 1828 |
+
|
| 1829 |
+
# Conv Params
|
| 1830 |
+
bias, stride, padding, dilation, groups = (
|
| 1831 |
+
kwargs["b"],
|
| 1832 |
+
kwargs["stride"],
|
| 1833 |
+
kwargs["padding"],
|
| 1834 |
+
kwargs["dilation"],
|
| 1835 |
+
kwargs["groups"],
|
| 1836 |
+
)
|
| 1837 |
+
|
| 1838 |
+
x_shape = qx.meta.get("tensor_meta").shape
|
| 1839 |
+
if has_free_symbols(x_shape):
|
| 1840 |
+
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
| 1841 |
+
x_shape = None
|
| 1842 |
+
graph = match.graph
|
| 1843 |
+
with graph.inserting_before(conv_node):
|
| 1844 |
+
# Insert weight prepack node and the QConv node
|
| 1845 |
+
packed_weight_inputs = (
|
| 1846 |
+
qw,
|
| 1847 |
+
w_scale,
|
| 1848 |
+
x_scale,
|
| 1849 |
+
x_zp,
|
| 1850 |
+
stride,
|
| 1851 |
+
padding,
|
| 1852 |
+
dilation,
|
| 1853 |
+
groups,
|
| 1854 |
+
x_shape,
|
| 1855 |
+
)
|
| 1856 |
+
packed_weight_op = torch.ops.onednn.qconv_prepack
|
| 1857 |
+
prepack_weight_node = graph.call_function(
|
| 1858 |
+
packed_weight_op, args=packed_weight_inputs
|
| 1859 |
+
)
|
| 1860 |
+
|
| 1861 |
+
new_args: Tuple[Any, ...] = (
|
| 1862 |
+
qx,
|
| 1863 |
+
x_scale,
|
| 1864 |
+
x_zp,
|
| 1865 |
+
prepack_weight_node,
|
| 1866 |
+
w_scale,
|
| 1867 |
+
w_zp,
|
| 1868 |
+
bias,
|
| 1869 |
+
stride,
|
| 1870 |
+
padding,
|
| 1871 |
+
dilation,
|
| 1872 |
+
groups,
|
| 1873 |
+
1.0, # output_scale
|
| 1874 |
+
0, # output_zero_point
|
| 1875 |
+
dtype, # output_dtype
|
| 1876 |
+
"none", # attr
|
| 1877 |
+
[], # scalars
|
| 1878 |
+
"", # algorithm
|
| 1879 |
+
)
|
| 1880 |
+
new_conv_node = graph.call_function(
|
| 1881 |
+
torch.ops.onednn.qconv2d_pointwise.default, args=new_args
|
| 1882 |
+
)
|
| 1883 |
+
conv_node.replace_all_uses_with(new_conv_node)
|
| 1884 |
+
new_conv_node.meta.update(conv_node.meta)
|
| 1885 |
+
|
| 1886 |
+
# Erase the original conv node
|
| 1887 |
+
graph.erase_node(conv_node)
|
| 1888 |
+
# Erase the dequant pattern
|
| 1889 |
+
if dtype == torch.bfloat16:
|
| 1890 |
+
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
|
| 1891 |
+
graph.erase_node(dequant_node) # type: ignore[arg-type]
|
| 1892 |
+
# Erase the dequant per channel pattern
|
| 1893 |
+
if clone_node is not None:
|
| 1894 |
+
graph.erase_node(clone_node) # type: ignore[arg-type]
|
| 1895 |
+
if dtype == torch.bfloat16:
|
| 1896 |
+
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
|
| 1897 |
+
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
|
| 1898 |
+
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
|
| 1899 |
+
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
|
| 1900 |
+
match.nodes
|
| 1901 |
+
)
|
| 1902 |
+
|
| 1903 |
+
|
| 1904 |
+
def _generate_dequant_convolution_node_pattern(
|
| 1905 |
+
_dequant_per_channel_pattern, dtype=torch.float32
|
| 1906 |
+
):
|
| 1907 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 1908 |
+
dequant_convolution_node_pattern = CallFunction(
|
| 1909 |
+
aten.convolution.default,
|
| 1910 |
+
_may_generate_pattern_with_dtype_convert(
|
| 1911 |
+
get_dequantize_per_tensor_activation_pattern(),
|
| 1912 |
+
KeywordArg("autocast_act_dtype"),
|
| 1913 |
+
dtype == torch.bfloat16,
|
| 1914 |
+
),
|
| 1915 |
+
_dequant_per_channel_pattern,
|
| 1916 |
+
KeywordArg("b"),
|
| 1917 |
+
KeywordArg("stride"),
|
| 1918 |
+
KeywordArg("padding"),
|
| 1919 |
+
KeywordArg("dilation"),
|
| 1920 |
+
KeywordArg("is_transposed"),
|
| 1921 |
+
KeywordArg("out_padding"),
|
| 1922 |
+
KeywordArg("groups"),
|
| 1923 |
+
)
|
| 1924 |
+
return dequant_convolution_node_pattern
|
| 1925 |
+
|
| 1926 |
+
|
| 1927 |
+
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
|
| 1928 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 1929 |
+
return (
|
| 1930 |
+
_generate_dequant_convolution_node_pattern(
|
| 1931 |
+
dequantize_per_channel_weight_pattern
|
| 1932 |
+
if dtype == torch.float32
|
| 1933 |
+
else dequantize_per_channel_to_bf16_weight_pattern,
|
| 1934 |
+
dtype,
|
| 1935 |
+
),
|
| 1936 |
+
# There is another pattern due to the pass of convert_conv_weights_to_channels_last
|
| 1937 |
+
# https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
|
| 1938 |
+
# Depend on some heuristics, it may or may not insert to(channel_last) node
|
| 1939 |
+
# between convolution and dequant_per_channel node
|
| 1940 |
+
_generate_dequant_convolution_node_pattern(
|
| 1941 |
+
dequantize_per_channel_clone_weight_pattern
|
| 1942 |
+
if dtype == torch.float32
|
| 1943 |
+
else dequantize_per_channel_to_bf16_clone_weight_pattern,
|
| 1944 |
+
dtype,
|
| 1945 |
+
),
|
| 1946 |
+
)
|
| 1947 |
+
|
| 1948 |
+
|
| 1949 |
+
def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
|
| 1950 |
+
output_reshape_node = None
|
| 1951 |
+
if input_dim_exceeds_two:
|
| 1952 |
+
if input_contiguous:
|
| 1953 |
+
output_reshape_node = match.output_node()
|
| 1954 |
+
assert output_reshape_node.target is aten.reshape.default
|
| 1955 |
+
linear_node = output_reshape_node.args[0]
|
| 1956 |
+
else:
|
| 1957 |
+
linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
|
| 1958 |
+
assert len(linear_nodes) == 1
|
| 1959 |
+
linear_node = linear_nodes[0]
|
| 1960 |
+
else:
|
| 1961 |
+
linear_node = match.output_node()
|
| 1962 |
+
|
| 1963 |
+
assert linear_node.target in (
|
| 1964 |
+
aten.addmm.default,
|
| 1965 |
+
aten.mm.default,
|
| 1966 |
+
aten.bmm.default,
|
| 1967 |
+
)
|
| 1968 |
+
return linear_node, output_reshape_node
|
| 1969 |
+
|
| 1970 |
+
|
| 1971 |
+
def _get_linear_dq_node(
|
| 1972 |
+
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
| 1973 |
+
):
|
| 1974 |
+
act_reshape_node = None
|
| 1975 |
+
activation_to_bf16_node = None
|
| 1976 |
+
act_expand_node = None
|
| 1977 |
+
if input_dim_exceeds_two:
|
| 1978 |
+
if input_contiguous:
|
| 1979 |
+
act_reshape_node = linear_node.args[input_index]
|
| 1980 |
+
assert act_reshape_node.target is aten.reshape.default
|
| 1981 |
+
if dtype == torch.float32:
|
| 1982 |
+
# pattern: linear -> reshape -> dequant
|
| 1983 |
+
dequant_node = act_reshape_node.args[0]
|
| 1984 |
+
else:
|
| 1985 |
+
# pattern: linear -> reshape -> to_bf16 -> dequant
|
| 1986 |
+
activation_to_bf16_node = act_reshape_node.args[0]
|
| 1987 |
+
dequant_node = activation_to_bf16_node.args[0]
|
| 1988 |
+
else:
|
| 1989 |
+
# bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
|
| 1990 |
+
act_expand_node = linear_node.args[input_index]
|
| 1991 |
+
assert act_expand_node.target is aten.expand.default
|
| 1992 |
+
if dtype == torch.float32:
|
| 1993 |
+
dequant_node = act_expand_node.args[0]
|
| 1994 |
+
else:
|
| 1995 |
+
activation_to_bf16_node = act_expand_node.args[0]
|
| 1996 |
+
dequant_node = activation_to_bf16_node.args[0]
|
| 1997 |
+
else:
|
| 1998 |
+
if dtype == torch.float32:
|
| 1999 |
+
# pattern: linear -> dequant
|
| 2000 |
+
dequant_node = linear_node.args[input_index]
|
| 2001 |
+
else:
|
| 2002 |
+
# pattern: linear -> to_bf16 -> dequant
|
| 2003 |
+
activation_to_bf16_node = linear_node.args[input_index]
|
| 2004 |
+
dequant_node = activation_to_bf16_node.args[0]
|
| 2005 |
+
return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
|
| 2006 |
+
|
| 2007 |
+
|
| 2008 |
+
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
|
| 2009 |
+
def _inner(match):
|
| 2010 |
+
# Check dequant pattern has only 1 user.
|
| 2011 |
+
(
|
| 2012 |
+
linear_node,
|
| 2013 |
+
_,
|
| 2014 |
+
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
|
| 2015 |
+
|
| 2016 |
+
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
| 2017 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 2018 |
+
(
|
| 2019 |
+
dequant_node,
|
| 2020 |
+
_,
|
| 2021 |
+
_,
|
| 2022 |
+
_,
|
| 2023 |
+
) = _get_linear_dq_node(
|
| 2024 |
+
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
| 2025 |
+
)
|
| 2026 |
+
|
| 2027 |
+
assert dequant_node.target in [
|
| 2028 |
+
quantized_decomposed.dequantize_per_tensor.default,
|
| 2029 |
+
quantized_decomposed.dequantize_per_tensor.tensor,
|
| 2030 |
+
]
|
| 2031 |
+
|
| 2032 |
+
if len(list(dequant_node.users)) != 1:
|
| 2033 |
+
# Ensure the dequant pattern only has 1 user
|
| 2034 |
+
# since we will delete the dequant pattern here
|
| 2035 |
+
return False
|
| 2036 |
+
|
| 2037 |
+
# Extra check for bmm pattern
|
| 2038 |
+
if input_dim_exceeds_two and not input_contiguous:
|
| 2039 |
+
# Check for act
|
| 2040 |
+
# Act expand size should be exactly same as act size
|
| 2041 |
+
act_expand_size = match.kwargs["act_expand_size"]
|
| 2042 |
+
act_node = match.kwargs["x"]
|
| 2043 |
+
if not (
|
| 2044 |
+
hasattr(act_node, "meta")
|
| 2045 |
+
and isinstance(act_node.meta.get("val", None), torch.Tensor)
|
| 2046 |
+
and (act_node.meta["val"].size() == torch.Size(act_expand_size))
|
| 2047 |
+
):
|
| 2048 |
+
return False
|
| 2049 |
+
|
| 2050 |
+
# Check for wgt
|
| 2051 |
+
# wgt permute dims should be [1, 0]
|
| 2052 |
+
wgt_permute_dims = match.kwargs["permute_axes"]
|
| 2053 |
+
if wgt_permute_dims != [1, 0]:
|
| 2054 |
+
return False
|
| 2055 |
+
|
| 2056 |
+
# Check below wgt size items:
|
| 2057 |
+
# wgt before expand should with dim 2
|
| 2058 |
+
# Expand size should with dim 3
|
| 2059 |
+
# Expand size[0] should same as act size[0]
|
| 2060 |
+
# Expand size[1] should same as wgt size[1]
|
| 2061 |
+
# Expand size[2] should same as wgt size[0]
|
| 2062 |
+
qweight_node = match.kwargs["q_weight"]
|
| 2063 |
+
wgt_expand_size = match.kwargs["wgt_expand_size"]
|
| 2064 |
+
if not (
|
| 2065 |
+
hasattr(qweight_node, "meta")
|
| 2066 |
+
and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
|
| 2067 |
+
and len(qweight_node.meta["val"].size()) == 2
|
| 2068 |
+
and len(wgt_expand_size) == 3
|
| 2069 |
+
and wgt_expand_size[0] == act_node.meta["val"].size()[0]
|
| 2070 |
+
and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
|
| 2071 |
+
and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
|
| 2072 |
+
):
|
| 2073 |
+
return False
|
| 2074 |
+
|
| 2075 |
+
return True
|
| 2076 |
+
|
| 2077 |
+
return _inner
|
| 2078 |
+
|
| 2079 |
+
|
| 2080 |
+
def _register_qlinear_weight_prepack_pass(
|
| 2081 |
+
pattern,
|
| 2082 |
+
pass_number,
|
| 2083 |
+
dtype=torch.float32,
|
| 2084 |
+
input_dim_exceeds_two=False,
|
| 2085 |
+
input_contiguous=True,
|
| 2086 |
+
):
|
| 2087 |
+
@register_freezing_graph_pattern(
|
| 2088 |
+
pattern,
|
| 2089 |
+
extra_check=_is_valid_dequant_linear_pattern(
|
| 2090 |
+
dtype, input_dim_exceeds_two, input_contiguous
|
| 2091 |
+
),
|
| 2092 |
+
pass_number=pass_number,
|
| 2093 |
+
)
|
| 2094 |
+
def qlinear_weight_prepack(match: Match, *args, **kwargs):
|
| 2095 |
+
"""
|
| 2096 |
+
Match the pattern:
|
| 2097 |
+
int8 activation
|
| 2098 |
+
|
|
| 2099 |
+
dequant_per_tensor
|
| 2100 |
+
|
|
| 2101 |
+
mm/addmm <- t <- dequant_per_channel <- int8_weight
|
| 2102 |
+
|
| 2103 |
+
Insert weight prepack node and change the pattern to:
|
| 2104 |
+
int8 activation
|
| 2105 |
+
|
|
| 2106 |
+
onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
|
| 2107 |
+
"""
|
| 2108 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 2109 |
+
(
|
| 2110 |
+
linear_node,
|
| 2111 |
+
output_reshape_node,
|
| 2112 |
+
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
|
| 2113 |
+
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
| 2114 |
+
weight_index = input_index + 1
|
| 2115 |
+
|
| 2116 |
+
(
|
| 2117 |
+
dequant_node,
|
| 2118 |
+
act_reshape_node,
|
| 2119 |
+
activation_to_bf16_node,
|
| 2120 |
+
act_expand_node,
|
| 2121 |
+
) = _get_linear_dq_node(
|
| 2122 |
+
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
| 2123 |
+
)
|
| 2124 |
+
|
| 2125 |
+
if input_dim_exceeds_two and not input_contiguous:
|
| 2126 |
+
wgt_expand_node = linear_node.args[weight_index]
|
| 2127 |
+
assert wgt_expand_node.target is aten.expand.default
|
| 2128 |
+
t_node = wgt_expand_node.args[0]
|
| 2129 |
+
else:
|
| 2130 |
+
t_node = linear_node.args[weight_index]
|
| 2131 |
+
|
| 2132 |
+
if dtype == torch.float32:
|
| 2133 |
+
dequant_per_channel = t_node.args[0]
|
| 2134 |
+
else:
|
| 2135 |
+
weight_to_bf16_node = t_node.args[0]
|
| 2136 |
+
dequant_per_channel = weight_to_bf16_node.args[0]
|
| 2137 |
+
assert (
|
| 2138 |
+
dequant_per_channel.target
|
| 2139 |
+
is quantized_decomposed.dequantize_per_channel.default
|
| 2140 |
+
)
|
| 2141 |
+
|
| 2142 |
+
# Activation QParams
|
| 2143 |
+
qx, x_zp, x_scale = (
|
| 2144 |
+
kwargs["x"],
|
| 2145 |
+
kwargs["x_zp"],
|
| 2146 |
+
kwargs["x_scale"],
|
| 2147 |
+
)
|
| 2148 |
+
|
| 2149 |
+
# Weight QParams
|
| 2150 |
+
qw, w_scale, w_zp = (
|
| 2151 |
+
kwargs["q_weight"],
|
| 2152 |
+
kwargs["w_scale"],
|
| 2153 |
+
kwargs["w_zp"],
|
| 2154 |
+
)
|
| 2155 |
+
|
| 2156 |
+
# Params
|
| 2157 |
+
bias = kwargs["b"] if "b" in kwargs else None
|
| 2158 |
+
|
| 2159 |
+
x_shape = qx.meta.get("tensor_meta").shape
|
| 2160 |
+
if has_free_symbols(x_shape):
|
| 2161 |
+
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
| 2162 |
+
x_shape = None
|
| 2163 |
+
graph = match.graph
|
| 2164 |
+
with graph.inserting_before(linear_node):
|
| 2165 |
+
# Insert weight prepack node and the qlinear node
|
| 2166 |
+
packed_weight_inputs = (
|
| 2167 |
+
qw,
|
| 2168 |
+
x_shape,
|
| 2169 |
+
)
|
| 2170 |
+
packed_weight_op = torch.ops.onednn.qlinear_prepack
|
| 2171 |
+
prepack_weight_node = graph.call_function(
|
| 2172 |
+
packed_weight_op, args=packed_weight_inputs
|
| 2173 |
+
)
|
| 2174 |
+
|
| 2175 |
+
new_args: Tuple[Any, ...] = (
|
| 2176 |
+
qx,
|
| 2177 |
+
x_scale,
|
| 2178 |
+
x_zp,
|
| 2179 |
+
prepack_weight_node,
|
| 2180 |
+
w_scale,
|
| 2181 |
+
w_zp,
|
| 2182 |
+
bias,
|
| 2183 |
+
1.0, # output_scale
|
| 2184 |
+
0, # output_zero_point
|
| 2185 |
+
dtype, # output_dtype
|
| 2186 |
+
"none", # post op name
|
| 2187 |
+
[], # post op args
|
| 2188 |
+
"", # post op algorithm
|
| 2189 |
+
)
|
| 2190 |
+
Node = torch.fx.node.Node
|
| 2191 |
+
if isinstance(x_scale, Node) and isinstance(x_zp, Node):
|
| 2192 |
+
new_linear_node = graph.call_function(
|
| 2193 |
+
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
|
| 2194 |
+
)
|
| 2195 |
+
else:
|
| 2196 |
+
new_linear_node = graph.call_function(
|
| 2197 |
+
torch.ops.onednn.qlinear_pointwise.default, args=new_args
|
| 2198 |
+
)
|
| 2199 |
+
if input_dim_exceeds_two:
|
| 2200 |
+
if input_contiguous:
|
| 2201 |
+
output_reshape_node.replace_all_uses_with(new_linear_node)
|
| 2202 |
+
new_linear_node.meta.update(output_reshape_node.meta)
|
| 2203 |
+
else:
|
| 2204 |
+
if bias:
|
| 2205 |
+
output_add_node_for_bias = match.output_node()
|
| 2206 |
+
assert output_add_node_for_bias.target is aten.add.Tensor
|
| 2207 |
+
output_add_node_for_bias.replace_all_uses_with(new_linear_node)
|
| 2208 |
+
new_linear_node.meta.update(output_add_node_for_bias.meta)
|
| 2209 |
+
else:
|
| 2210 |
+
linear_node.replace_all_uses_with(new_linear_node)
|
| 2211 |
+
new_linear_node.meta.update(linear_node.meta)
|
| 2212 |
+
else:
|
| 2213 |
+
linear_node.replace_all_uses_with(new_linear_node)
|
| 2214 |
+
new_linear_node.meta.update(linear_node.meta)
|
| 2215 |
+
|
| 2216 |
+
# Erase the original linear node
|
| 2217 |
+
if input_dim_exceeds_two:
|
| 2218 |
+
if input_contiguous:
|
| 2219 |
+
graph.erase_node(output_reshape_node)
|
| 2220 |
+
elif not input_contiguous and bias:
|
| 2221 |
+
graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
|
| 2222 |
+
graph.erase_node(linear_node)
|
| 2223 |
+
if input_dim_exceeds_two:
|
| 2224 |
+
if input_contiguous:
|
| 2225 |
+
graph.erase_node(act_reshape_node)
|
| 2226 |
+
else:
|
| 2227 |
+
graph.erase_node(act_expand_node)
|
| 2228 |
+
graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
|
| 2229 |
+
if dtype == torch.bfloat16:
|
| 2230 |
+
graph.erase_node(activation_to_bf16_node)
|
| 2231 |
+
# Erase the dequant pattern
|
| 2232 |
+
graph.erase_node(dequant_node)
|
| 2233 |
+
# Erase the dequant per channel pattern
|
| 2234 |
+
graph.erase_node(t_node)
|
| 2235 |
+
if dtype == torch.bfloat16:
|
| 2236 |
+
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
|
| 2237 |
+
graph.erase_node(dequant_per_channel)
|
| 2238 |
+
|
| 2239 |
+
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
|
| 2240 |
+
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
|
| 2241 |
+
match.nodes
|
| 2242 |
+
)
|
| 2243 |
+
|
| 2244 |
+
|
| 2245 |
+
def _generate_dequant_linear_node_pattern(
|
| 2246 |
+
_dequant_per_channel_pattern,
|
| 2247 |
+
dtype=torch.float32,
|
| 2248 |
+
input_dim_exceeds_two=False,
|
| 2249 |
+
is_tensor_overload=False,
|
| 2250 |
+
):
|
| 2251 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 2252 |
+
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
|
| 2253 |
+
dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
|
| 2254 |
+
CallFunction(
|
| 2255 |
+
aten.addmm.default,
|
| 2256 |
+
KeywordArg("b"),
|
| 2257 |
+
_may_generate_pattern_with_reshape(
|
| 2258 |
+
_may_generate_pattern_with_dtype_convert(
|
| 2259 |
+
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
| 2260 |
+
KeywordArg("autocast_act_dtype"),
|
| 2261 |
+
dtype == torch.bfloat16,
|
| 2262 |
+
),
|
| 2263 |
+
KeywordArg("act_reshape_size"),
|
| 2264 |
+
input_dim_exceeds_two,
|
| 2265 |
+
),
|
| 2266 |
+
t_pattern,
|
| 2267 |
+
),
|
| 2268 |
+
KeywordArg("output_reshape_size"),
|
| 2269 |
+
input_dim_exceeds_two,
|
| 2270 |
+
)
|
| 2271 |
+
dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
|
| 2272 |
+
CallFunction(
|
| 2273 |
+
aten.mm.default,
|
| 2274 |
+
_may_generate_pattern_with_reshape(
|
| 2275 |
+
_may_generate_pattern_with_dtype_convert(
|
| 2276 |
+
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
| 2277 |
+
KeywordArg("autocast_act_dtype"),
|
| 2278 |
+
dtype == torch.bfloat16,
|
| 2279 |
+
),
|
| 2280 |
+
KeywordArg("act_reshape_size"),
|
| 2281 |
+
input_dim_exceeds_two,
|
| 2282 |
+
),
|
| 2283 |
+
t_pattern,
|
| 2284 |
+
),
|
| 2285 |
+
KeywordArg("output_reshape_size"),
|
| 2286 |
+
input_dim_exceeds_two,
|
| 2287 |
+
)
|
| 2288 |
+
return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
|
| 2289 |
+
|
| 2290 |
+
|
| 2291 |
+
def _generate_dequant_bmm_node_pattern(
|
| 2292 |
+
_dequant_per_channel_pattern,
|
| 2293 |
+
dtype=torch.float32,
|
| 2294 |
+
with_bias=False,
|
| 2295 |
+
is_tensor_overload=False,
|
| 2296 |
+
):
|
| 2297 |
+
# When activation of linear dim exceed 2 and not contiguous
|
| 2298 |
+
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
|
| 2299 |
+
|
| 2300 |
+
assert dtype in [torch.float32, torch.bfloat16]
|
| 2301 |
+
dequant_bmm_pattern = CallFunction(
|
| 2302 |
+
aten.bmm.default,
|
| 2303 |
+
CallFunction(
|
| 2304 |
+
aten.expand.default,
|
| 2305 |
+
_may_generate_pattern_with_dtype_convert(
|
| 2306 |
+
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
|
| 2307 |
+
KeywordArg("autocast_act_dtype"),
|
| 2308 |
+
dtype == torch.bfloat16,
|
| 2309 |
+
),
|
| 2310 |
+
KeywordArg("act_expand_size"),
|
| 2311 |
+
),
|
| 2312 |
+
CallFunction(
|
| 2313 |
+
aten.expand.default,
|
| 2314 |
+
t_pattern,
|
| 2315 |
+
KeywordArg("wgt_expand_size"),
|
| 2316 |
+
),
|
| 2317 |
+
)
|
| 2318 |
+
|
| 2319 |
+
def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
|
| 2320 |
+
if _with_bias:
|
| 2321 |
+
return CallFunction(
|
| 2322 |
+
aten.add.Tensor,
|
| 2323 |
+
_dequant_bmm_pattern,
|
| 2324 |
+
KeywordArg("b"),
|
| 2325 |
+
)
|
| 2326 |
+
else:
|
| 2327 |
+
return _dequant_bmm_pattern
|
| 2328 |
+
|
| 2329 |
+
return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
|
| 2330 |
+
|
| 2331 |
+
|
| 2332 |
+
def _generate_qlinear_weight_prepack_patterns(
|
| 2333 |
+
dtype=torch.float32,
|
| 2334 |
+
input_dim_exceeds_two=False,
|
| 2335 |
+
input_contiguous=True,
|
| 2336 |
+
with_bias=False,
|
| 2337 |
+
is_tensor_overload=False,
|
| 2338 |
+
):
|
| 2339 |
+
if input_dim_exceeds_two and not input_contiguous:
|
| 2340 |
+
return _generate_dequant_bmm_node_pattern(
|
| 2341 |
+
dequantize_per_channel_weight_pattern,
|
| 2342 |
+
dtype,
|
| 2343 |
+
with_bias,
|
| 2344 |
+
is_tensor_overload,
|
| 2345 |
+
)
|
| 2346 |
+
else:
|
| 2347 |
+
return _generate_dequant_linear_node_pattern(
|
| 2348 |
+
dequantize_per_channel_weight_pattern,
|
| 2349 |
+
dtype,
|
| 2350 |
+
input_dim_exceeds_two,
|
| 2351 |
+
is_tensor_overload,
|
| 2352 |
+
)
|
| 2353 |
+
|
| 2354 |
+
|
| 2355 |
+
def _register_dequant_promotion():
|
| 2356 |
+
dequant_pattern_cases = itertools.product(
|
| 2357 |
+
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
| 2358 |
+
)
|
| 2359 |
+
for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
|
| 2360 |
+
# 4 dequantization patterns will be matched based on the dtype and input dimension size.
|
| 2361 |
+
# Case 1: int8-mixed-fp32, input dim size is 2
|
| 2362 |
+
# Case 2: int8-mixed-fp32, input dim size exceeds 2
|
| 2363 |
+
# Case 3: int8-mixed-bf16, input dim size is 2
|
| 2364 |
+
# Case 4: int8-mixed-bf16, input dim size exceeds 2
|
| 2365 |
+
# quant
|
| 2366 |
+
# + - - - - | - - - - +
|
| 2367 |
+
# | dequant |
|
| 2368 |
+
# | | |
|
| 2369 |
+
# | OPT(to_bf16) |
|
| 2370 |
+
# | | |
|
| 2371 |
+
# | OPT(reshape) |
|
| 2372 |
+
# | / \ |
|
| 2373 |
+
# | node1 node2 |
|
| 2374 |
+
# + - - | - - - | - - +
|
| 2375 |
+
# OPT(reshape) OPT(reshape)
|
| 2376 |
+
# + - - | - - - | - - +
|
| 2377 |
+
# OPT(to_fp32) OPT(to_fp32)
|
| 2378 |
+
# + - - | - - - | - - +
|
| 2379 |
+
# quant quant
|
| 2380 |
+
_register_dequant_promotion_pass(
|
| 2381 |
+
_may_generate_pattern_with_reshape(
|
| 2382 |
+
_may_generate_pattern_with_dtype_convert(
|
| 2383 |
+
get_dequantize_per_tensor_activation_pattern(
|
| 2384 |
+
is_tensor_overload=is_tensor_overload
|
| 2385 |
+
),
|
| 2386 |
+
KeywordArg("autocast_act_dtype"),
|
| 2387 |
+
dtype == torch.bfloat16,
|
| 2388 |
+
),
|
| 2389 |
+
KeywordArg("act_reshape_size"),
|
| 2390 |
+
with_reshape=input_dim_exceeds_two,
|
| 2391 |
+
),
|
| 2392 |
+
pass_number=0,
|
| 2393 |
+
dtype=dtype,
|
| 2394 |
+
) # pass_number=0 to run before weight prepack
|
| 2395 |
+
|
| 2396 |
+
|
| 2397 |
+
def _register_qconv_weight_prepack():
|
| 2398 |
+
for dtype in [torch.float32, torch.bfloat16]:
|
| 2399 |
+
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
|
| 2400 |
+
for weight_prepack_pattern in weight_prepack_patterns:
|
| 2401 |
+
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
|
| 2402 |
+
_register_qconv_weight_prepack_pass(
|
| 2403 |
+
weight_prepack_pattern, pass_number=1, dtype=dtype
|
| 2404 |
+
)
|
| 2405 |
+
|
| 2406 |
+
|
| 2407 |
+
def _register_qlinear_weight_prepack():
|
| 2408 |
+
# 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous.
|
| 2409 |
+
# Then convert the pattern into a QLinear node with int8_fp32/bf16.
|
| 2410 |
+
# Case 1: int8-mixed-fp32, input dim size is 2
|
| 2411 |
+
# Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
|
| 2412 |
+
# Case 3: int8-mixed-bf16, input dim size is 2
|
| 2413 |
+
# Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
|
| 2414 |
+
|
| 2415 |
+
# + - - - - | - - - - - - | - - - - - +
|
| 2416 |
+
# | dq_per_tensor dq_per_channel |
|
| 2417 |
+
# | | | |
|
| 2418 |
+
# | OPT(to_bf16) OPT(to_bf16) |
|
| 2419 |
+
# | | | |
|
| 2420 |
+
# | OPT(reshape) permute |
|
| 2421 |
+
# | \ / |
|
| 2422 |
+
# | addmm/mm |
|
| 2423 |
+
# | | |
|
| 2424 |
+
# | OPT(reshape) |
|
| 2425 |
+
|
| 2426 |
+
# Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
|
| 2427 |
+
# Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
|
| 2428 |
+
|
| 2429 |
+
# + - - - - | - - - - - - | - - - - - +
|
| 2430 |
+
# | dq_per_tensor dq_per_channel |
|
| 2431 |
+
# | | | |
|
| 2432 |
+
# | OPT(to_bf16) OPT(to_bf16) |
|
| 2433 |
+
# | | | |
|
| 2434 |
+
# | expand permute |
|
| 2435 |
+
# | \ | |
|
| 2436 |
+
# | expand |
|
| 2437 |
+
# | / |
|
| 2438 |
+
# | bmm |
|
| 2439 |
+
# | | |
|
| 2440 |
+
# | OPT(add) |
|
| 2441 |
+
|
| 2442 |
+
linear_weight_prepack_cases = itertools.product(
|
| 2443 |
+
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
| 2444 |
+
)
|
| 2445 |
+
|
| 2446 |
+
# Step 1: register patterns from mm and addmm
|
| 2447 |
+
for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
|
| 2448 |
+
weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
|
| 2449 |
+
dtype,
|
| 2450 |
+
input_dim_exceeds_two,
|
| 2451 |
+
is_tensor_overload=is_tensor_overload,
|
| 2452 |
+
)
|
| 2453 |
+
for weight_prepack_pattern in weight_prepack_patterns:
|
| 2454 |
+
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
|
| 2455 |
+
_register_qlinear_weight_prepack_pass(
|
| 2456 |
+
weight_prepack_pattern,
|
| 2457 |
+
pass_number=1,
|
| 2458 |
+
dtype=dtype,
|
| 2459 |
+
input_dim_exceeds_two=input_dim_exceeds_two,
|
| 2460 |
+
)
|
| 2461 |
+
|
| 2462 |
+
# Step 2: register patterns from bmm
|
| 2463 |
+
# Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
|
| 2464 |
+
# refer to:
|
| 2465 |
+
# https://github.com/pytorch/pytorch/blob/
|
| 2466 |
+
# 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
|
| 2467 |
+
# in this case, we can convert it back to qlinear
|
| 2468 |
+
for dtype, with_bias, is_tensor_overload in itertools.product(
|
| 2469 |
+
[torch.float32, torch.bfloat16], [True, False], [True, False]
|
| 2470 |
+
):
|
| 2471 |
+
bmm_pattern = _generate_qlinear_weight_prepack_patterns(
|
| 2472 |
+
dtype=dtype,
|
| 2473 |
+
input_dim_exceeds_two=True,
|
| 2474 |
+
input_contiguous=False,
|
| 2475 |
+
with_bias=with_bias,
|
| 2476 |
+
is_tensor_overload=is_tensor_overload,
|
| 2477 |
+
)
|
| 2478 |
+
_register_qlinear_weight_prepack_pass(
|
| 2479 |
+
bmm_pattern,
|
| 2480 |
+
pass_number=1
|
| 2481 |
+
if with_bias
|
| 2482 |
+
else 2, # if with_bias, there is an output add, so we should try to match it firstly
|
| 2483 |
+
dtype=dtype,
|
| 2484 |
+
input_dim_exceeds_two=True,
|
| 2485 |
+
input_contiguous=False,
|
| 2486 |
+
)
|
| 2487 |
+
|
| 2488 |
+
|
| 2489 |
+
@functools.lru_cache(None)
|
| 2490 |
+
def _register_quantization_weight_pack_pass():
|
| 2491 |
+
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
|
| 2492 |
+
_register_dequant_promotion()
|
| 2493 |
+
|
| 2494 |
+
# Step 2: QConv weight prepack
|
| 2495 |
+
_register_qconv_weight_prepack()
|
| 2496 |
+
|
| 2497 |
+
# Step 3: QLinear weight prepack
|
| 2498 |
+
_register_qlinear_weight_prepack()
|
| 2499 |
+
|
| 2500 |
+
|
| 2501 |
+
def quant_lift_up(graph_module: torch.fx.GraphModule):
|
| 2502 |
+
"""
|
| 2503 |
+
Lift up the quant node before view like nodes. It can benefit performance
|
| 2504 |
+
of Attention like block. For example, we have the pattern as:
|
| 2505 |
+
|
| 2506 |
+
DQ
|
| 2507 |
+
DQ LINEAR
|
| 2508 |
+
LINEAR VIEW
|
| 2509 |
+
VIEW PERMUTE
|
| 2510 |
+
PERMUTE TRANSPOSE
|
| 2511 |
+
Q Q
|
| 2512 |
+
DQ DQ
|
| 2513 |
+
Matmul
|
| 2514 |
+
DIV
|
| 2515 |
+
ADD
|
| 2516 |
+
SOFTMAX
|
| 2517 |
+
|
| 2518 |
+
We want to lift up the the quant nodes from matmul before view like nodes
|
| 2519 |
+
as the output of Linear node.
|
| 2520 |
+
|
| 2521 |
+
DQ
|
| 2522 |
+
DQ LINEAR
|
| 2523 |
+
LINEAR Q
|
| 2524 |
+
Q VIEW
|
| 2525 |
+
VIEW PERMUTE
|
| 2526 |
+
PERMUTE TRANSPOSE
|
| 2527 |
+
DQ DQ
|
| 2528 |
+
Matmul
|
| 2529 |
+
DIV
|
| 2530 |
+
ADD
|
| 2531 |
+
SOFTMAX
|
| 2532 |
+
|
| 2533 |
+
It produces a DQ->LINEAR->Q pattern which can be fused by backend.
|
| 2534 |
+
"""
|
| 2535 |
+
|
| 2536 |
+
def is_view_op(node):
|
| 2537 |
+
return node.op == "call_function" and node.target in _VIEW_OPS
|
| 2538 |
+
|
| 2539 |
+
for node in graph_module.graph.nodes:
|
| 2540 |
+
# <TODO> Leslie: Here we verify that the quant node has exactly
|
| 2541 |
+
# one input FX node, with constant scalar value for scale and zero point.
|
| 2542 |
+
# For the case input of quant node has more than one input FX nodes,
|
| 2543 |
+
# extend the implementation to lift up all the connected nodes
|
| 2544 |
+
# before the view nodes to keep the topological order.
|
| 2545 |
+
if (
|
| 2546 |
+
node.op == "call_function"
|
| 2547 |
+
and node.target in _PER_TENSOR_QUANTIZE_OPS
|
| 2548 |
+
and len(node.all_input_nodes) == 1
|
| 2549 |
+
and is_view_op(node.all_input_nodes[0])
|
| 2550 |
+
):
|
| 2551 |
+
quant_node = node
|
| 2552 |
+
input_node_of_quant = quant_node.args[0]
|
| 2553 |
+
|
| 2554 |
+
# Check the nodes along lift up path has only 1 user node
|
| 2555 |
+
# Propagate view like node to find where to insert the new quant node
|
| 2556 |
+
could_lift_up = True
|
| 2557 |
+
current_node = quant_node
|
| 2558 |
+
input_node = current_node.args[0]
|
| 2559 |
+
while is_view_op(input_node):
|
| 2560 |
+
if len(input_node.users) != 1:
|
| 2561 |
+
could_lift_up = False
|
| 2562 |
+
break
|
| 2563 |
+
current_node = input_node
|
| 2564 |
+
input_node = current_node.args[0]
|
| 2565 |
+
|
| 2566 |
+
# Further check the input node of the first view node has only 1 user node
|
| 2567 |
+
if could_lift_up and len(input_node.users) == 1:
|
| 2568 |
+
# Replace dequant's input from quant to quant's input
|
| 2569 |
+
quant_node.replace_all_uses_with(input_node_of_quant)
|
| 2570 |
+
# Insert the new quant node
|
| 2571 |
+
with graph_module.graph.inserting_before(current_node):
|
| 2572 |
+
new_quant_node = graph_module.graph.node_copy(quant_node)
|
| 2573 |
+
input_node.replace_all_uses_with(new_quant_node)
|
| 2574 |
+
|
| 2575 |
+
# Update inputs of new_quant_node
|
| 2576 |
+
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
|
| 2577 |
+
if n == input_node_of_quant:
|
| 2578 |
+
return input_node
|
| 2579 |
+
else:
|
| 2580 |
+
return n
|
| 2581 |
+
|
| 2582 |
+
new_args = map_arg(new_quant_node.args, maybe_replace_node)
|
| 2583 |
+
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
|
| 2584 |
+
new_quant_node.args = new_args # type: ignore[assignment]
|
| 2585 |
+
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
|
| 2586 |
+
graph_module.graph.erase_node(quant_node)
|
| 2587 |
+
|
| 2588 |
+
graph_module.graph.lint()
|
| 2589 |
+
graph_module.recompile()
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Callable, Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch._higher_order_ops.triton_kernel_wrap import (
|
| 11 |
+
kernel_side_table,
|
| 12 |
+
triton_kernel_wrapper_functional,
|
| 13 |
+
)
|
| 14 |
+
from torch._inductor import config, inductor_prims
|
| 15 |
+
from torch._inductor.fx_utils import get_node_storage, is_node_realized
|
| 16 |
+
from torch._inductor.lowering import (
|
| 17 |
+
inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
|
| 18 |
+
)
|
| 19 |
+
from torch._inductor.virtualized import V
|
| 20 |
+
from torch.fx.immutable_collections import immutable_dict
|
| 21 |
+
from torch.fx.passes.reinplace import _is_view_op
|
| 22 |
+
from torch.utils import _pytree as pytree
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
log = logging.getLogger(__name__)
|
| 26 |
+
aten = torch.ops.aten
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class InplaceableOp:
|
| 31 |
+
inplace_op: Callable[..., Any]
|
| 32 |
+
mutated_arg: int
|
| 33 |
+
extra_check: Callable[[torch.fx.Node], bool] = lambda node: True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_SCATTER_OP_TO_VIEW = {
|
| 37 |
+
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
|
| 38 |
+
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
|
| 39 |
+
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
|
| 40 |
+
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
|
| 41 |
+
}
|
| 42 |
+
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
|
| 46 |
+
fake_args, fake_kwargs = pytree.tree_map(
|
| 47 |
+
lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
|
| 48 |
+
(args, kwargs),
|
| 49 |
+
)
|
| 50 |
+
with V.fake_mode:
|
| 51 |
+
fake_result = fn(*fake_args, **fake_kwargs)
|
| 52 |
+
|
| 53 |
+
node = graph.call_function(fn, args, kwargs)
|
| 54 |
+
node.meta["val"] = fake_result
|
| 55 |
+
return node
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class ViewOp:
|
| 60 |
+
target: torch._ops.OpOverload
|
| 61 |
+
args: Tuple[Any, ...]
|
| 62 |
+
kwargs: Dict[str, Any]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _inplace_generalized_scatter(
|
| 66 |
+
inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
|
| 67 |
+
) -> torch.Tensor:
|
| 68 |
+
tmp = inp
|
| 69 |
+
for view in view_ops:
|
| 70 |
+
fake_args, fake_kwargs = pytree.tree_map(
|
| 71 |
+
lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
|
| 72 |
+
(view.args, view.kwargs),
|
| 73 |
+
)
|
| 74 |
+
tmp = view.target(tmp, *fake_args, **fake_kwargs)
|
| 75 |
+
try:
|
| 76 |
+
tmp.copy_(src)
|
| 77 |
+
except RuntimeError as e:
|
| 78 |
+
raise RuntimeError(
|
| 79 |
+
f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}"
|
| 80 |
+
) from e
|
| 81 |
+
return inp
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _generalized_scatter(
|
| 85 |
+
inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
|
| 86 |
+
) -> torch.Tensor:
|
| 87 |
+
out = inp.clone()
|
| 88 |
+
return _inplace_generalized_scatter(out, src, view_ops)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _decompose_scatter_functional_helper(
|
| 92 |
+
graph: torch.fx.Graph,
|
| 93 |
+
inp: torch.Tensor,
|
| 94 |
+
src: torch.Tensor,
|
| 95 |
+
view_ops: List[ViewOp],
|
| 96 |
+
) -> torch.fx.Node:
|
| 97 |
+
view_op, view_ops_tail = view_ops[0], view_ops[1:]
|
| 98 |
+
|
| 99 |
+
if view_ops_tail:
|
| 100 |
+
view = graph_call_function(
|
| 101 |
+
graph, view_op.target, inp, *view_op.args, **view_op.kwargs
|
| 102 |
+
)
|
| 103 |
+
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment]
|
| 104 |
+
|
| 105 |
+
return graph_call_function(
|
| 106 |
+
graph,
|
| 107 |
+
_VIEW_OP_TO_SCATTER[view_op.target],
|
| 108 |
+
inp,
|
| 109 |
+
src,
|
| 110 |
+
*view_op.args,
|
| 111 |
+
**view_op.kwargs,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _decompose_scatter_functional(
|
| 116 |
+
graph: torch.fx.Graph, node: torch.fx.Node
|
| 117 |
+
) -> torch.fx.Node:
|
| 118 |
+
"""Decompose _generalized_scatter to a sequence of view_scatter operations
|
| 119 |
+
|
| 120 |
+
e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
|
| 121 |
+
|
| 122 |
+
will become
|
| 123 |
+
|
| 124 |
+
view = aten.slice(inp, 0, 0, 10)
|
| 125 |
+
view_updated = aten.slice_scatter(view, src, 1, 10, -10)
|
| 126 |
+
inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
|
| 127 |
+
"""
|
| 128 |
+
assert node.target is _generalized_scatter
|
| 129 |
+
inp, src, view_ops = node.args
|
| 130 |
+
return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _decompose_scatter_mutating(
|
| 134 |
+
graph: torch.fx.Graph, node: torch.fx.Node
|
| 135 |
+
) -> torch.fx.Node:
|
| 136 |
+
"""Decompose _generalized_scatter using mutations
|
| 137 |
+
|
| 138 |
+
e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
|
| 139 |
+
|
| 140 |
+
will become
|
| 141 |
+
|
| 142 |
+
inp_updated = aten.clone(inp)
|
| 143 |
+
slice1 = aten.slice(inp_updated, 0, 0, 10)
|
| 144 |
+
slice2 = aten.slice(slice1, 1, 10, -10)
|
| 145 |
+
slice2.copy_(src)
|
| 146 |
+
|
| 147 |
+
"""
|
| 148 |
+
assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
|
| 149 |
+
inp, src, view_ops = node.args
|
| 150 |
+
assert not node.kwargs
|
| 151 |
+
|
| 152 |
+
if node.target is _generalized_scatter:
|
| 153 |
+
inp = graph_call_function(graph, aten.clone, inp)
|
| 154 |
+
|
| 155 |
+
tmp = inp
|
| 156 |
+
for view in view_ops: # type: ignore[union-attr]
|
| 157 |
+
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
|
| 158 |
+
|
| 159 |
+
graph_call_function(graph, aten.copy_.default, tmp, src)
|
| 160 |
+
return inp # type: ignore[return-value]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# View ops whose view_scatter op is lowered into mutations anyway,
|
| 164 |
+
# so is never a pessimisation to decompose.
|
| 165 |
+
_ALWAYS_MUTATING_SCATTER_OPS = {
|
| 166 |
+
aten.as_strided.default,
|
| 167 |
+
aten.diagonal.default,
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
|
| 172 |
+
_, _, view_ops = node.args
|
| 173 |
+
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
|
| 177 |
+
"""Choose between mutating and functional scatter decompositions
|
| 178 |
+
|
| 179 |
+
Reinplacing view scatter ops can be pessimising as it blocks fusion with the
|
| 180 |
+
input or output tensor computations. However, it is still profitable if the
|
| 181 |
+
input and output would have been realized anyway.
|
| 182 |
+
|
| 183 |
+
"""
|
| 184 |
+
inp, src, view_ops = node.args
|
| 185 |
+
|
| 186 |
+
# Mutating scatter ops unconditionally realize input and output
|
| 187 |
+
if scatter_always_uses_mutation(node):
|
| 188 |
+
return True
|
| 189 |
+
|
| 190 |
+
if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type]
|
| 191 |
+
return True
|
| 192 |
+
|
| 193 |
+
# If the output is copied back into the input, this forces both to be
|
| 194 |
+
# realized as the output is a user of the input
|
| 195 |
+
if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr]
|
| 196 |
+
user.target is aten.copy_.default and user.args[0] is inp for user in node.users
|
| 197 |
+
):
|
| 198 |
+
return True
|
| 199 |
+
|
| 200 |
+
# Otherwise, assume fusions will make functional variants profitable
|
| 201 |
+
return False
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
|
| 205 |
+
"""Replace _generalized_scatter with normal aten ops"""
|
| 206 |
+
for node in itertools.chain(
|
| 207 |
+
graph.find_nodes(op="call_function", target=_generalized_scatter),
|
| 208 |
+
graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
|
| 209 |
+
):
|
| 210 |
+
use_mutation = (
|
| 211 |
+
node.target is _inplace_generalized_scatter
|
| 212 |
+
or scatter_always_uses_mutation(node)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
with graph.inserting_before(node):
|
| 216 |
+
if use_mutation:
|
| 217 |
+
new_node = _decompose_scatter_mutating(graph, node)
|
| 218 |
+
else:
|
| 219 |
+
new_node = _decompose_scatter_functional(graph, node)
|
| 220 |
+
|
| 221 |
+
node.replace_all_uses_with(new_node)
|
| 222 |
+
graph.erase_node(node)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
| 226 |
+
"""
|
| 227 |
+
This canonicalizes view scatter ops into a generalized form, defined as:
|
| 228 |
+
def scatter(inp, src, views):
|
| 229 |
+
tmp = inp.clone()
|
| 230 |
+
for view in views:
|
| 231 |
+
tmp = view(tmp)
|
| 232 |
+
tmp.copy_(src)
|
| 233 |
+
|
| 234 |
+
We also fuse consecutive view scatter ops of the form
|
| 235 |
+
a = scatter(view2(self), src, [view1])
|
| 236 |
+
b = scatter(self, a, [view2])
|
| 237 |
+
which can be rewritten as
|
| 238 |
+
b = scatter(self, src, [view2, view1])
|
| 239 |
+
a = view2(b)
|
| 240 |
+
|
| 241 |
+
This is both more efficient as we only do a single scatter, and also
|
| 242 |
+
easier to reinplace since there is only one use of `self`
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
|
| 246 |
+
node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)
|
| 247 |
+
|
| 248 |
+
def handle_views(node: torch.fx.Node):
|
| 249 |
+
inp = node.args[0]
|
| 250 |
+
node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
|
| 251 |
+
node_to_view_op[node] = [
|
| 252 |
+
*node_to_view_op[inp], # type: ignore[index]
|
| 253 |
+
ViewOp(
|
| 254 |
+
node.target, # type: ignore[arg-type]
|
| 255 |
+
args=node.args[1:],
|
| 256 |
+
kwargs=node.kwargs,
|
| 257 |
+
),
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
def handle_view_scatter(node: torch.fx.Node):
|
| 261 |
+
assert len(node.args) >= 2
|
| 262 |
+
inp, src = node.args[:2]
|
| 263 |
+
|
| 264 |
+
scatter_view_op = ViewOp(
|
| 265 |
+
_SCATTER_OP_TO_VIEW[node.target],
|
| 266 |
+
args=node.args[2:],
|
| 267 |
+
kwargs=node.kwargs,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
def can_fuse():
|
| 271 |
+
if src.target is not _generalized_scatter: # type: ignore[union-attr]
|
| 272 |
+
return False
|
| 273 |
+
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
| 274 |
+
|
| 275 |
+
inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
|
| 276 |
+
src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type]
|
| 277 |
+
return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index]
|
| 278 |
+
*node_to_view_op[inp], # type: ignore[index]
|
| 279 |
+
scatter_view_op,
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
if not can_fuse():
|
| 283 |
+
with graph.inserting_before(node):
|
| 284 |
+
new_node = graph_call_function(
|
| 285 |
+
graph,
|
| 286 |
+
_generalized_scatter,
|
| 287 |
+
inp,
|
| 288 |
+
src,
|
| 289 |
+
[scatter_view_op],
|
| 290 |
+
)
|
| 291 |
+
node.replace_all_uses_with(new_node)
|
| 292 |
+
graph.erase_node(node)
|
| 293 |
+
return
|
| 294 |
+
|
| 295 |
+
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
| 296 |
+
with graph.inserting_before(src): # type: ignore[arg-type]
|
| 297 |
+
new_node = graph_call_function(
|
| 298 |
+
graph,
|
| 299 |
+
_generalized_scatter,
|
| 300 |
+
inp,
|
| 301 |
+
src_src,
|
| 302 |
+
[scatter_view_op, *src_scatter_view_op], # type: ignore[misc]
|
| 303 |
+
)
|
| 304 |
+
node.replace_all_uses_with(new_node)
|
| 305 |
+
graph.erase_node(node)
|
| 306 |
+
|
| 307 |
+
if src.users: # type: ignore[union-attr]
|
| 308 |
+
new_src = graph_call_function(
|
| 309 |
+
graph,
|
| 310 |
+
_SCATTER_OP_TO_VIEW[node.target],
|
| 311 |
+
new_node,
|
| 312 |
+
*node.args[2:],
|
| 313 |
+
**node.kwargs,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
handle_views(new_src)
|
| 317 |
+
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
|
| 318 |
+
|
| 319 |
+
graph.erase_node(src) # type: ignore[arg-type]
|
| 320 |
+
|
| 321 |
+
for node in graph.nodes:
|
| 322 |
+
if _is_view_op(node.target):
|
| 323 |
+
handle_views(node)
|
| 324 |
+
elif node.target in _SCATTER_OP_TO_VIEW:
|
| 325 |
+
handle_view_scatter(node)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
inplaceable_ops = {
|
| 329 |
+
aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
|
| 330 |
+
aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
|
| 331 |
+
_generalized_scatter: InplaceableOp(
|
| 332 |
+
_inplace_generalized_scatter,
|
| 333 |
+
0,
|
| 334 |
+
extra_check=should_reinplace_scatter,
|
| 335 |
+
),
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
try:
|
| 339 |
+
c10d_functional = torch.ops._c10d_functional
|
| 340 |
+
inplaceable_collective_ops = {
|
| 341 |
+
c10d_functional.all_reduce.default: InplaceableOp(
|
| 342 |
+
c10d_functional.all_reduce_.default, 0
|
| 343 |
+
),
|
| 344 |
+
c10d_functional.all_reduce_coalesced.default: InplaceableOp(
|
| 345 |
+
c10d_functional.all_reduce_coalesced_.default, 0
|
| 346 |
+
),
|
| 347 |
+
}
|
| 348 |
+
inplaceable_ops.update(inplaceable_collective_ops)
|
| 349 |
+
except AttributeError:
|
| 350 |
+
# _c10d_functional ops are only available when torch
|
| 351 |
+
# is built with USE_DISTRIBUTED=1.
|
| 352 |
+
pass
|
| 353 |
+
|
| 354 |
+
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
|
| 355 |
+
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
|
| 356 |
+
inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
inplaceable_triton_ops = {triton_kernel_wrapper_functional}
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# Operators that don't depend on the tensor data
|
| 363 |
+
META_ONLY_OPS = {
|
| 364 |
+
aten.sym_size.int,
|
| 365 |
+
aten.sym_stride.int,
|
| 366 |
+
aten.sym_numel.default,
|
| 367 |
+
aten.sym_storage_offset.default,
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
|
| 372 |
+
"""
|
| 373 |
+
Reinplaces in-placeable operations.
|
| 374 |
+
If there are no uses of a view of the mutated arg after the current node,
|
| 375 |
+
it is possible to inplace the op.
|
| 376 |
+
This above algorithm could be justified by observing side effects. While
|
| 377 |
+
we traverse the graph in forwards direction, only latter nodes could view
|
| 378 |
+
side effects of the current node. If the current node is not used later as
|
| 379 |
+
well as no view of this node is used later in the graph, then it is safe to
|
| 380 |
+
inplace as there would be no way to observe the side effects.
|
| 381 |
+
This condition is slightly different for graph inputs where they can only
|
| 382 |
+
be inplaced if the above condition is true and there's a copy_ in the
|
| 383 |
+
epilogue that signals that the caller wants to observe the mutation.
|
| 384 |
+
|
| 385 |
+
Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
|
| 386 |
+
input args, so instead of checking mutation on placeholder, AOTInductor
|
| 387 |
+
checks mutation on get_attr. This is subject to change in future.
|
| 388 |
+
"""
|
| 389 |
+
|
| 390 |
+
copy_args_to_copy_nodes = {}
|
| 391 |
+
# maps argument to the first copy_ node that mutates it.
|
| 392 |
+
copy_nodes = {}
|
| 393 |
+
mutated_inputs = set()
|
| 394 |
+
storage_to_nodes = defaultdict(list)
|
| 395 |
+
node_order: Dict[Any, int] = {}
|
| 396 |
+
for i, node in enumerate(reversed(graph.nodes)):
|
| 397 |
+
node_order[node] = len(graph.nodes) - i - 1
|
| 398 |
+
storage_to_nodes[get_node_storage(node)].append(node)
|
| 399 |
+
if node.target == aten.copy_.default and node.args[0].op in (
|
| 400 |
+
"placeholder",
|
| 401 |
+
"get_attr",
|
| 402 |
+
):
|
| 403 |
+
dst = node.args[0]
|
| 404 |
+
src = node.args[1]
|
| 405 |
+
# If the target is a getitem and it indexes a possible clone,
|
| 406 |
+
# then skip over it
|
| 407 |
+
if src.target == operator.getitem and (
|
| 408 |
+
(
|
| 409 |
+
src.args[0].target == triton_kernel_wrapper_functional
|
| 410 |
+
and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
|
| 411 |
+
)
|
| 412 |
+
or (src.args[0].target in inplaceable_foreach_ops)
|
| 413 |
+
or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
|
| 414 |
+
):
|
| 415 |
+
src = src.args[0]
|
| 416 |
+
|
| 417 |
+
copy_args_to_copy_nodes[(dst, src)] = node
|
| 418 |
+
copy_nodes[dst] = node
|
| 419 |
+
|
| 420 |
+
mutated_inputs.add(node.args[0])
|
| 421 |
+
|
| 422 |
+
def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
|
| 423 |
+
node_loc = node_order[node]
|
| 424 |
+
copy_node_loc = node_order[copy_node] if copy_node is not None else None
|
| 425 |
+
|
| 426 |
+
def is_meta_only_user(node):
|
| 427 |
+
if _is_view_op(node.target):
|
| 428 |
+
return all(is_meta_only_user(u) for u in node.users)
|
| 429 |
+
return node.target in META_ONLY_OPS
|
| 430 |
+
|
| 431 |
+
for view in shared_view_nodes:
|
| 432 |
+
for user in view.users:
|
| 433 |
+
user_loc = node_order[user]
|
| 434 |
+
# Skip all users before node
|
| 435 |
+
if user_loc <= node_loc:
|
| 436 |
+
continue
|
| 437 |
+
# Ignore uses after the copy_ epilogue node, where the input
|
| 438 |
+
# has already been mutated anyway
|
| 439 |
+
if copy_node_loc is not None and copy_node_loc <= user_loc:
|
| 440 |
+
continue
|
| 441 |
+
# Reinplacing does not change shape metadata
|
| 442 |
+
if is_meta_only_user(user):
|
| 443 |
+
continue
|
| 444 |
+
# If our graph looks like:
|
| 445 |
+
# foo(mutated_arg)
|
| 446 |
+
# mutated_arg.copy_(other)
|
| 447 |
+
# then it's safe for us to reinplace foo because mutated_arg
|
| 448 |
+
# will get overwritten anyways.
|
| 449 |
+
if (
|
| 450 |
+
user.target is torch.ops.aten.copy_.default
|
| 451 |
+
and mutated_arg is user.args[0]
|
| 452 |
+
):
|
| 453 |
+
continue
|
| 454 |
+
return True
|
| 455 |
+
return False
|
| 456 |
+
|
| 457 |
+
def can_inplace(node, mutated_arg):
|
| 458 |
+
if isinstance(mutated_arg, (list, tuple)):
|
| 459 |
+
unique_storages = {get_node_storage(arg) for arg in mutated_arg}
|
| 460 |
+
if len(unique_storages) != len(mutated_arg):
|
| 461 |
+
# at least two Tensors in mutated_arg alias each other, so we can't reinplace it.
|
| 462 |
+
# We can probably do better (that is, reinplace one of them and clone the other)
|
| 463 |
+
# but that requires more work and mutable List[Tensor] are not that common.
|
| 464 |
+
return False
|
| 465 |
+
return all(can_inplace(node, arg) for arg in mutated_arg)
|
| 466 |
+
|
| 467 |
+
if get_node_storage(mutated_arg) is None:
|
| 468 |
+
return False
|
| 469 |
+
shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
|
| 470 |
+
|
| 471 |
+
if mutated_arg.op in ("placeholder", "get_attr"):
|
| 472 |
+
# Get the first copy_ node that mutates the mutated_arg.
|
| 473 |
+
copy_node = copy_nodes.get(mutated_arg, None)
|
| 474 |
+
if copy_node is None:
|
| 475 |
+
# There is no copy_ back to the candidate mutated_arg (which is a graph input).
|
| 476 |
+
# Therefore the semantics of the program are that it does not mutate
|
| 477 |
+
# mutated_arg, so we cannot re-inplace it.
|
| 478 |
+
return False
|
| 479 |
+
if any_use_of_views_after_node(
|
| 480 |
+
node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
|
| 481 |
+
):
|
| 482 |
+
return False
|
| 483 |
+
|
| 484 |
+
return True
|
| 485 |
+
elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
|
| 486 |
+
# This should never happen in auto_functionalize_v2 non-inference mode,
|
| 487 |
+
# since all mutated_arg are bases.
|
| 488 |
+
|
| 489 |
+
# If mutated arg is view of any of the inputs of the graph,
|
| 490 |
+
# do not allow for inplacing.
|
| 491 |
+
# This would require more sophisticated algorithm to handle
|
| 492 |
+
return False
|
| 493 |
+
else:
|
| 494 |
+
return not any_use_of_views_after_node(
|
| 495 |
+
node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
def log_inplace_results(
|
| 499 |
+
node_name,
|
| 500 |
+
old_tensors_to_clone,
|
| 501 |
+
tensors_to_clone,
|
| 502 |
+
possibly_missed_reinplacing_opportunities,
|
| 503 |
+
):
|
| 504 |
+
log.info(
|
| 505 |
+
"For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
|
| 506 |
+
"%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
|
| 507 |
+
"memory usage and performance.",
|
| 508 |
+
node_name,
|
| 509 |
+
old_tensors_to_clone,
|
| 510 |
+
tensors_to_clone,
|
| 511 |
+
possibly_missed_reinplacing_opportunities,
|
| 512 |
+
)
|
| 513 |
+
torch._dynamo.utils.counters["inductor"][
|
| 514 |
+
"possibly_missed_reinplacing_opportunities"
|
| 515 |
+
] += len(possibly_missed_reinplacing_opportunities)
|
| 516 |
+
|
| 517 |
+
replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}
|
| 518 |
+
|
| 519 |
+
def reinplace_and_refine_tensors_to_clone(
    old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
):
    """Attempt to reinplace each candidate arg; return the args that still
    need cloning.

    For every arg in ``old_tensors_to_clone`` this either (a) reinplaces it,
    recording the removal of its trailing ``copy_`` node (and, for
    non-v2 auto_functionalize, the forwarding of ``getitem`` users) in
    ``replace_dict``, or (b) keeps it in the returned clone list. Outcomes
    are reported via ``log_inplace_results``.

    Args:
        old_tensors_to_clone: candidate args — kwarg names, or base indices
            when ``auto_functionalize_v2`` is True.
        kwargs: mapping from candidate arg to the mutated fx.Node.
        node_name: name used for logging.
        auto_functionalize_v2: True when candidates are v2 base indices.

    Returns:
        The subset of ``old_tensors_to_clone`` that must still be cloned.
    """
    tensors_to_clone: List[str] = []
    # Storages already claimed by a reinplaced arg in this call; a second arg
    # aliasing the same storage must be cloned instead (see comment below).
    storage_of_reinplaced_args = set()
    possibly_missed_reinplacing_opportunities = []

    def tensor_with_same_storage_already_reinplaced(arg):
        if isinstance(arg, (list, tuple)):
            return any(
                get_node_storage(a) in storage_of_reinplaced_args for a in arg
            )
        # Fixed: this branch previously read the closure variable
        # `mutated_arg` instead of the `arg` parameter. That was benign only
        # because the sole call site passes `mutated_arg`, but it silently
        # ignored the parameter for any other caller.
        return get_node_storage(arg) in storage_of_reinplaced_args

    for arg in old_tensors_to_clone:
        assert arg in kwargs

        mutated_arg = kwargs[arg]

        # Let's say we have:
        # - op(x, y) that mutates both x and y
        # - new_x, new_y = functional_op(x, y) is the functional variant
        # If we are presented with functional_op(x, x), we must not reinplace
        # this into op(x, x), because then it would be writing to the same Tensor.
        # Instead, it's OK to reinplace one of them and to clone the other:
        # >>> y = x.clone()
        # >>> op(x, y)
        # This also applies if we have views: functional_op(x, x[0])
        # should not reinplace into op(x, x[0]).
        should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
            mutated_arg
        )
        if should_attempt_reinplace and can_inplace(node, mutated_arg):
            # In general, we probably do not need those optimizations.
            copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
            if copy_node is not None:
                replace_dict[copy_node] = copy_node.args[0]
            if not auto_functionalize_v2:
                for user in node.users:
                    # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
                    # output at index size(out)+i.
                    # This used to compare string with integers before for auto_functionalize_v2. Not sure
                    # if it was needed for inplaceable_triton_ops?
                    if user.target == operator.getitem and user.args[1] == arg:
                        replace_dict[user] = mutated_arg

            if isinstance(mutated_arg, (list, tuple)):
                for a in mutated_arg:
                    storage_of_reinplaced_args.add(get_node_storage(a))
            else:
                storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
        else:
            if should_attempt_reinplace:
                # Aliasing allowed it, but can_inplace said no — count it as
                # a possibly-missed opportunity for diagnostics.
                possibly_missed_reinplacing_opportunities.append(arg)
            tensors_to_clone.append(arg)

    log_inplace_results(
        node_name,
        old_tensors_to_clone,
        tensors_to_clone,
        possibly_missed_reinplacing_opportunities,
    )
    return tensors_to_clone
|
| 582 |
+
|
| 583 |
+
# Main rewrite loop: dispatch each node to the appropriate reinplacing
# strategy. copy_/getitem removals are only *recorded* in replace_dict here
# and applied after the loop, so iteration order is not disturbed.
for node in graph.nodes:
    if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
        mutated_arg = node.args[inplaceable_op.mutated_arg]
        if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
            # TODO(yifu): this doesn't properly remove copy epilogues for
            # ops that mutate multiple inputs. Need to revise the copy
            # node tracking logic to support the case.
            copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
            if copy_node is not None:
                replace_dict[copy_node] = copy_node.args[0]
            # Swap the functional op for its registered in-place variant.
            node.target = inplaceable_op.inplace_op
    elif node.target == torch.ops.higher_order.auto_functionalized_v2:
        _mutable_op = node.args[0]
        kwargs = node.kwargs

        # In v2, the reinplacing candidates are the base tensors, referred to
        # by integer index rather than by kwarg name.
        all_bases = kwargs["_all_bases"]
        bases_to_clone = range(len(all_bases))
        base_tensors_dct = dict(enumerate(all_bases))
        new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone(
            bases_to_clone,
            base_tensors_dct,
            node.target,
            auto_functionalize_v2=True,
        )
        # Stash the metadata. There is a pass later on where we decompose
        # auto_functionalized into clones + a mutable op; this metadata
        # tells the decomp to only clone the following inputs
        node.meta["only_clone_these_tensors"] = new_bases_to_clone
    elif node.target == torch.ops.higher_order.auto_functionalized:
        _mutable_op = node.args[0]
        from torch._higher_order_ops.auto_functionalize import get_mutable_args

        tensors_to_clone, _ = get_mutable_args(_mutable_op)
        # Don't try to reinplace Optional[Tensor] args that are None.
        tensors_to_clone = [
            t for t in tensors_to_clone if node.kwargs[t] is not None
        ]
        tensors_to_clone = reinplace_and_refine_tensors_to_clone(
            tensors_to_clone,
            node.kwargs,
            _mutable_op._name,
            auto_functionalize_v2=False,
        )

        # Stash the metadata. There is a pass later on where we decompose
        # auto_functionalized into clones + a mutable op; this metadata
        # tells the decomp to only clone the following inputs
        node.meta["only_clone_these_tensors"] = tensors_to_clone
    elif node.target in inplaceable_triton_ops:
        kernel_idx = node.kwargs["kernel_idx"]
        kernel = kernel_side_table.get_kernel(kernel_idx)
        from triton.runtime.autotuner import Autotuner
        from triton.runtime.jit import JITFunction

        # Recover a human-readable kernel name for logging; the attribute
        # holding it differs by wrapper type.
        if isinstance(kernel, JITFunction):
            kernel_name = kernel.fn.__name__
        elif isinstance(kernel, Autotuner):
            if config.is_fbcode():
                # Autotuner has different implementations for AMD and NV
                if torch.version.hip is None:
                    kernel_name = kernel.base_fn.__name__
                else:
                    kernel_name = kernel.fn.__name__
            else:
                kernel_name = kernel.base_fn.__name__
        else:
            raise AssertionError("Unknown triton kernel type")

        # inplaceable_triton_ops take an additional argument called
        # tensors_to_clone which contain a list of tensors to clone
        # This pass iterates over them and sees which ones are safe
        # to eliminate (i.e. no longer need the clones)
        tensors_to_clone = reinplace_and_refine_tensors_to_clone(
            node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
        )

        # node.kwargs is immutable; rebuild it with the refined clone list.
        kwargs = dict(node.kwargs)
        kwargs["tensors_to_clone"] = tensors_to_clone
        node.kwargs = immutable_dict(kwargs)
    elif (
        inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
    ) is not None:
        mutated_args = node.args[inplaceable_op.mutated_arg]

        # Only reinplace when every mutated arg has a trailing copy_ epilogue.
        if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
            continue

        if can_inplace(node, mutated_args):
            for arg in mutated_args:
                copy_node = copy_args_to_copy_nodes[(arg, node)]
                replace_dict[copy_node] = copy_node.args[0]

            node.target = inplaceable_op.inplace_op
# Apply the recorded replacements. A replacement may itself be scheduled for
# replacement (a -> b -> c), so follow the chain to its final target (and
# memoize it back into replace_dict) before rewriting users.
for node, replacement in replace_dict.items():
    while replacement in replace_dict:
        replacement = replace_dict[replacement]
    replace_dict[node] = replacement

    node.replace_all_uses_with(replacement)
    graph.erase_node(node)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    """Run the reinplacing pipeline on *graph* in order: canonicalize
    view-scatter ops, reinplace functional ops in place, then decompose the
    generalized scatter ops introduced by canonicalization."""
    for graph_pass in (
        canonicalize_view_scatter_ops,
        reinplace_inplaceable_ops_core,
        decompose_generalized_scatter,
    ):
        graph_pass(graph)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 7 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
| 8 |
+
|
| 9 |
+
from .. import config, inductor_prims
|
| 10 |
+
from ..pattern_matcher import (
|
| 11 |
+
CallFunctionVarArgs,
|
| 12 |
+
Match,
|
| 13 |
+
PatternMatcherPass,
|
| 14 |
+
register_graph_pattern,
|
| 15 |
+
)
|
| 16 |
+
from ..virtualized import V
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Module-level logger for this pass.
log = logging.getLogger(__name__)
# Registry that collects the patterns registered below via
# @register_graph_pattern; applied by replace_random_passes.
patterns = PatternMatcherPass()
# Shorthand for the ATen operator namespace.
aten = torch.ops.aten
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def replace_random_passes(gm: torch.fx.GraphModule):
    """Modify the given FX graph to use backend-native random ops.

    Returns the number of rewrites performed (0 when
    ``config.fallback_random`` disables the pass).
    """
    # Respect the config knob that keeps eager-compatible RNG behavior.
    if config.fallback_random:
        return 0

    total = patterns.apply(gm)

    observer = GraphTransformObserver(
        gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
    )
    with observer:
        total += fuse_seed_creation_pass(gm.graph)

    return total
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the seed generation on each device

    a = inductor_seed(dev)
    b = inductor_seed(dev)

    Becomes:
    seeds = inductor_seeds(2, dev)
    a = inductor_lookup_seed(seeds, 0)
    b = inductor_lookup_seed(seeds, 1)

    We do this because seed creation is entirely launch overhead bound.

    Returns the number of devices whose seeds were fused (0 if none found).
    """
    # Group all inductor_prims.seed(...) nodes by their device argument.
    device_seeds = collections.defaultdict(list)
    for node in graph.nodes:
        if CallFunctionVarArgs(inductor_prims.seed).match(node):
            device_seeds[node.args[0]].append(node)

    if not device_seeds:
        return 0

    for device, seeds in device_seeds.items():
        # Insert one fused inductor_prims.seeds call before the first seed
        # node of this device.
        with graph.inserting_before(seeds[0]):
            combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
            # Fake-mode metadata: a 1-D int64 tensor holding len(seeds) seeds.
            with V.fake_mode:
                combined.meta["val"] = torch.empty(
                    [len(seeds)], device=device, dtype=torch.int64
                )
                combined.meta["tensor_meta"] = _extract_tensor_metadata(
                    combined.meta["val"]
                )

        # Replace each original seed node with a lookup into the fused buffer.
        for idx, seed in enumerate(seeds):
            with graph.inserting_before(seed):
                new_seed = graph.call_function(
                    inductor_prims.lookup_seed, (combined, idx)
                )
            seed.replace_all_uses_with(new_seed)
            # Carry over metadata (val/tensor_meta) from the replaced node.
            new_seed.meta.update(seed.meta)
            graph.erase_node(seed)

    return len(device_seeds)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def default_kwargs(device):
    """Extra keyword arguments to pass to the backend random prims for
    *device*; currently none are needed for any device."""
    return dict()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_device(device):
    """Return *device* unchanged, or the current default device when None."""
    if device is None:
        # An empty allocation lands on the default device; use its .device.
        return torch.empty([]).device
    return device
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite matched aten.rand / aten.randn calls into
    inductor_prims.random.

    Registered for both the .default and .generator overloads; calls with an
    explicit generator are left untouched (early return below).
    """
    if generator is not None:
        return

    def replacement(size):
        # NOTE: `mode` and `device` are late-bound closure variables — they
        # are assigned below, before replace_by_example traces this function.
        result = inductor_prims.random(
            size, inductor_prims.seed(device), mode, **default_kwargs(device)
        )
        if dtype is not None:
            result = result.to(dtype)
        return result

    # Map the matched overload packet (aten.rand / aten.randn) to the string
    # mode accepted by inductor_prims.random.
    mode = {
        aten.rand: "rand",
        aten.randn: "randn",
    }[
        match.output_node().target.overloadpacket  # type: ignore[union-attr]
    ]  # type: ignore[union-attr]
    device = get_device(device)
    match.replace_by_example(replacement, [size])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite matched aten.randint.low calls into inductor_prims.randint,
    casting to the requested dtype."""
    # Resolve the device up front; the replacement closes over it.
    device = get_device(device)

    def replacement(low, high, size):
        seed = inductor_prims.seed(device)
        return inductor_prims.randint(low, high, size, seed).to(dtype)

    match.replace_by_example(replacement, [low, high, size])
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc
ADDED
|
Binary file (9.56 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc
ADDED
|
Binary file (49.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc
ADDED
|
Binary file (37.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|