diff --git a/.gitattributes b/.gitattributes index 79fae07f586a29ba2802c1f8fe8eebc37337f771..0b10ef2aa75c0640f04a077b4f6b8df5dea6fab5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -137,3 +137,8 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab890966a929054aa103bc22e66f86036f63c471 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ir.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:092aa1e8b674926d96609f7d70e837e88bb2433dce56bfb3b265696082850bf7 +size 361762 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..817ddcfc6f7f1f5a006b3be628a00454bec6b2f2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e246926a23a6eb2b52e06e182bbc58f99d342ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:088ceca24b4ba43a80ac34a889d96ba95ca4779739aea2e41a34600e2f7fd8ae +size 103679 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9072415e9004576e3db5e5576affb930e10d64 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9dde4c92669d913e3e0a7bf7bbc82b533f3e5492ce09ea3322607c4c9cec549 +size 106873 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d431103e314de1ab3dcad05e50a5332b487cd76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:454c0716087aee149fe6ab1aaf1a2f50703e82c70c9edf9bd21bb618627510ad +size 176338 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4e641169d851ed1248451ac701e9d2e0899aa51 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d043ca807b2f2cb8a6fa93a9c0fe15f2091627b58bf0206323eb3d52b7c26cc +size 122305 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e52dc8f5f46c3ca0b6f9ed858ad7b280da85ea2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,354 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred." << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) { + aoti_torch_grad_mode_set_enabled(false); + } + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + bool prev_mode; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? "cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0" << std::endl; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59b36dac069c6d504a3efbacaa97e3ee546c90dd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc9a5b29c71e8c6f49cff5ed5f2a3ae850122861 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..6eec71344ae8cc7bb1c3c59cd226e22bf6339e29 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class XPUDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch._C import _xpu_getCurrentRawStream as {name}" + + def set_device(self, device_idx): + return f"torch.xpu.set_device({device_idx})" + + def synchronize(self): + return "torch.xpu.synchronize()" + + def device_guard(self, device_idx): + return f"torch.xpu._DeviceGuard({device_idx})" + + +register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ed43ea4c7bfc691a11833b2e70b10e40a5b4e60 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5a854b5b9d99413285118f351805f77dd94f7ab3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/b2b_gemm.py @@ -0,0 +1,746 @@ +# mypy: allow-untyped-defs +import functools +from collections import deque +from typing import Dict, List, Set, Tuple + +import torch +from torch.utils._pytree import tree_map + +from ..._dynamo.utils import counters +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import lowerings +from ..pattern_matcher import ( + Arg, + CallFunction, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, + TritonTemplateCaller, +) +from ..utils import ceildiv + + +B2B_GEMM_PASS = PatternMatcherPass( + pass_name="b2b_gemm_pass", +) + + +def b2b_gemm_grid(M, P, meta): + return (ceildiv(M, meta["BLOCK_SIZE_M"]) * ceildiv(P, meta["BLOCK_SIZE_P"]), 1, 1) + + +b2b_gemm_left_template = TritonTemplate( + name="b2b_gemm_left", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_LEFT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (subgraph(A @ B) @ C) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for _ in range(num_o_block): + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for __ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + acc_ab += tl.dot(a, b, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_ab", + inner_mm="acc_ab" + ) | indent_except_first(2) }} + acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +b2b_gemm_right_template = TritonTemplate( + name="b2b_gemm_right", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_RIGHT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop (two cases) + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (A @ subgraph(B @ C)) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for _ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for __ in range(num_o_block): + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_bc += tl.dot(b, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_bc", + inner_mm="acc_bc" + ) | indent_except_first(2) }} + acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +# Note: load_ratio_left and load_ratio_right are only calculating numbers +# in the trivial subgraph case; i.e. (A @ (B @ C)) or ((A @ B) @ C) + + +def load_ratio_left( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o)) + | store | M * O + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = M * N + N * O + M * O + O * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(O, o) + * (o * p + ceildiv(N, n) * (m * n + n * o)) + ) + return base / gemm + + +def load_ratio_right( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p)) + | store | N * P + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = N * O + O * P + M * N + N * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(N, n) + * (m * n + ceildiv(O, o) * (n * o + o * p)) + ) + return base / gemm + + +# the block sizes are limited by hardware (the shared memory) +# intuitively, the optimization works when the intermediate matrix is large +# and we assign large block sizes to large dimensions +b2b_gemm_configs = [ + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, +] + + +def is_b2b_gemm_good_on( + is_left_assoc: bool, + A_node: torch.fx.Node, + B_node: torch.fx.Node, + C_node: torch.fx.Node, +) -> bool: + """ + checks whether the sizes are good for b2b_gemm + """ + # basic checks + if not all(["val" in A_node.meta, "val" in B_node.meta, "val" in C_node.meta]): + return False + A, B, C = ( + A_node.meta["val"], + B_node.meta["val"], + C_node.meta["val"], + ) # torch._subclasses.fake_tensor.FakeTensor + if not all([A.is_cuda, B.is_cuda, C.is_cuda]): + return False + if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]): + return False + if not ((A.shape[1] == B.shape[0]) and (B.shape[1] == C.shape[0])): + return False + # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1 + M, N = A.shape + O, P = C.shape + ratios = [] + if is_left_assoc: + for config in b2b_gemm_configs: + ratio = load_ratio_left( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + else: + for config in b2b_gemm_configs: + ratio = load_ratio_right( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + ratios.sort(reverse=True) + average_ratio = 1.0 + for r in ratios[:3]: # top 3 choices + average_ratio *= r + average_ratio = average_ratio ** (1 / 3) + return ( + average_ratio > 1 + ) # even if average_ratio is close to 1, the number of stores is always better + + +def unoptimized_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + *, + out: torch.Tensor, +) -> torch.Tensor: + """ + The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial. + """ + if is_left_assoc: + torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out) + else: + torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out) + return out + + +unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm) + + +def build_subgraph_buffer( + args: List[TensorBox], + subgraph: Subgraph, +): + """ + This function is adapted from ../kernel/flex_attention.py. + The goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph + subgraph: The Subgraph ir for which to produce the output node + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + if node.op == "placeholder": + env[node] = args[cnt] + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) + ) + env[node] = lowerings[node.target](*args, **kwargs) + elif node.op == "output": + + def convert_output_node_to_buffer(output): + if output is None: + return None + output_node = output + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + # node.args[0] should be a single element representing the output of the subgraph + return tree_map(convert_output_node_to_buffer, node.args[0]) + + raise ValueError("B2B-GEMM was passed a subgraph with no output node!") + + +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """ + Creates a placeholder input buffers for producing subgraph_output + """ + input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], [])) + return TensorBox.create(input_buffer) + + +def tuned_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch._inductor.ir.TensorBox, + B: torch._inductor.ir.TensorBox, + C: torch._inductor.ir.TensorBox, + *, + layout=None, +) -> torch._inductor.ir.TensorBox: + # call .realize() to get rid of Pointwise + A.realize() + B.realize() + C.realize() + layout = FixedLayout(A.get_device(), A.get_dtype(), [A.shape[0], C.shape[1]]) + subgraph_buffer = build_subgraph_buffer( + [create_placeholder("inner_mm", A.get_dtype(), A.get_device())], + subgraph, + ) + choices: list[TritonTemplateCaller] = [] + for config in b2b_gemm_configs: + if is_left_assoc: + b2b_gemm_left_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + else: + b2b_gemm_right_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + # add the unoptimized choice to mitigate performance degradation + choices.append( + unoptimized_choice.bind( + (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph + ) + ) + # autotune + return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout) + + +# match the inner mm of a potential b2b_gemm +@register_graph_pattern( + CallFunction(torch.ops.aten.mm, Arg(), Arg()), + pass_dict=B2B_GEMM_PASS, +) +def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None: + # match.args: list[torch.fx.Node] + + def is_pointwise_node(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and (torch.Tag.pointwise in node.target.tags) + ) + + def is_mm(node: torch.fx.Node) -> bool: + return node.target == torch.ops.aten.mm.default + + # the inner MM + inner_mm = match.nodes[-1] + + # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it + # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _). + outer_mm = None + node = inner_mm + while len(node.users) > 0: + node = next(iter(node.users)) + if is_mm(node): + outer_mm = node + break + elif is_pointwise_node(node): + continue + else: + break + if not outer_mm: + return + + # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C)) + # we call it the "f_node" + # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm + f_node = inner_mm + while next(iter(f_node.users)) is not outer_mm: + f_node = next(iter(f_node.users)) + + def all_reach_via_pointwise_with_no_other_inputs( + src: torch.fx.Node, + dst: torch.fx.Node, + ) -> Tuple[bool, Set[torch.fx.Node]]: + """ + check whether every user path from src reaches dst via pointwise nodes, + with no other input nodes for the intermediates and dst; + return + (1) the Boolean value + (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True) + """ + visited: Set[torch.fx.Node] = set() + input_counter: Dict[torch.fx.Node, int] = {} + + all_reachable = True + queue = deque([src]) + while queue: + node = queue.popleft() + if node not in visited: + if node is dst: + visited.add(node) + elif (node is src) or is_pointwise_node(node): + for user in node.users.keys(): + # for nodes other than dst, bookkeep their users' input counts + if user not in input_counter: + input_counter[user] = len(user.all_input_nodes) + input_counter[user] -= 1 + # continue BFS + queue.append(user) + visited.add(node) + else: + all_reachable = False + break + + return ( + all_reachable and all(count == 0 for count in input_counter.values()), + visited, + ) + + # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes + ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs( + inner_mm, f_node + ) + if not ok: + return + + # check inner_mm's inputs and f_node's outputs + if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1): + return + + # at this point, the nodes between inner_mm and f_node (both included) + # are all used internally inside (A @ subgraph(B @ C)) + # i.e. they neither have other users nor have other inputs + + # original graph and module + graph, module = inner_mm.graph, inner_mm.graph.owning_module + + # construct the new (sub)graph + subgraph_node_list: List[ + torch.fx.Node + ] = [] # ordered list of nodes used for node removal later + new_graph: torch.fx.Graph = torch.fx.Graph() + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + new_input_anchor: torch.fx.Node # inner_mm, to be changed to an input node + new_output_anchor: torch.fx.Node # f_node, to be used to construct an output node + new_input_node: torch.fx.Node + new_output_node: torch.fx.Node + for node in graph.nodes: # preserve the order of nodes + if node in subgraph_node_set: + subgraph_node_list.append(node) + new_node = new_graph.node_copy( + node, lambda x: node_remapping[x] if x in node_remapping else x + ) + node_remapping[node] = new_node + if node is inner_mm: + new_input_anchor = new_node + if node is f_node: + new_output_anchor = new_node + if new_input_anchor is not new_output_anchor: # subgraph is non-trivial + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # add the output node + new_output_node = new_graph.output(new_output_anchor) + new_output_node.meta.update(new_output_anchor.meta) + else: # subgraph is trivial, e.g. (A @ (B @ C)) + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # update the output node (don't use new_output_anchor since it has been erased) + new_output_node = new_graph.output(new_input_node) + new_output_node.meta.update(new_input_node.meta) + new_graph.lint() + + # construct the subgraph + subgraph = Subgraph( + name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph) + ) + + # two cases + # (1) (subgraph(A @ B) @ C), called "left_assoc" + # (2) (A @ subgraph(B @ C)), called "right_assoc" + is_left_assoc = outer_mm.args[0] is f_node + + # find the nodes A, B, C and check the sizes + A: torch.fx.Node + B: torch.fx.Node + C: torch.fx.Node + if is_left_assoc: + A = inner_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[1] # type: ignore[assignment] + C = outer_mm.args[1] # type: ignore[assignment] + else: + A = outer_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[0] # type: ignore[assignment] + C = inner_mm.args[1] # type: ignore[assignment] + if not is_b2b_gemm_good_on(is_left_assoc, A, B, C): + return + + # finally update the original graph + counters["inductor"]["b2b_gemm"] += 1 + graph = match.graph + with graph.inserting_before(outer_mm): + function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph) + function.__name__ = tuned_b2b_gemm.__name__ # type: ignore[attr-defined] + function._inductor_lowering_function = True # type: ignore[attr-defined] + replacement: torch.fx.Node = graph.call_function( + function, + (A, B, C), + match.kwargs, + ) + replacement.meta.update(outer_mm.meta) + outer_mm.replace_all_uses_with(replacement) + # erase unnecessary nodes + graph.erase_node(outer_mm) + for node in reversed(subgraph_node_list): + graph.erase_node(node) + graph.lint() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1e0736d510faf10918530e0915d1c503dc09f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,276 @@ +# mypy: allow-untyped-defs +import functools +import itertools + +import torch + +from ..._dynamo.utils import counters +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype_conv(conv): + conv_dtype = conv.meta["val"].dtype + if conv_dtype not in (torch.float16, torch.bfloat16): + return + + if not len(conv.users) == 1: + return + + conv_user = next(iter(conv.users.keys())) + if not isinstance(conv_user.meta["val"], torch.Tensor): + return + + if not conv_user.meta["val"].dtype == torch.float32: + return + + while conv_user.target in _binary_ops: + if not len(conv_user.users) == 1: + return + + conv_user = next(iter(conv_user.users.keys())) + + if conv_user.target != prims.convert_element_type.default: + return + + conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype + + +def mark_mixed_dtype_allowed_convs(gm): + """ + Mark convolutions which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for node in gm.graph.find_nodes( + op="call_function", target=aten.convolution.default + ): + mark_mixed_dtype_conv(node) + + +def recover_original_precision_folded_convs(gm): + """ + After binary folding conv weights and biases to a higher dtype, recover the original precision they were in. + """ + graph = gm.graph + for node in graph.find_nodes(op="call_function", target=aten.convolution.default): + orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for idx in [1, 2]: + old_input = node.args[idx] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.lru_cache(None) +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _computation_ops = [aten.convolution.default] + _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)] + + """ + In order to fuse add/sub/mul/div with conv, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcast to w/ weight/bias tensor shape + - broadcast to the conv output shape + It needs to have a shape that can resize to weight/bias + tensor shape because we need to run the op with the conv + weights/bias without changing their sizes. + It needs to broadcast to the conv output shape so that we do + accidentally change the shape of op output by pre-fusing it + compared to eager. + The only dimension value shared by weight/bias/conv output + is they all contain a dim with value = channels-out. In the + conv output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...]. + for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(conv_node.args[1].users) == 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + else: + # TODO: support scalar case + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + computation_node = binary_node.args[0] + other = binary_node.args[1] + if binary_node.args[0].target not in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + if binary_node.args[0].target == aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape): + # TODO: support scalar case + if other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target == aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, other, tuple(weight_broadcast_shape) + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + computation_node = ( + binary_node.args[0] + if binary_node.args[0].target in _computation_ops + else binary_node.args[1] + ) + graph = match.graph + with graph.inserting_before(binary_node): + # TODO: support linear? + assert computation_node.target == aten.convolution.default + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + binary_node.replace_all_uses_with(new_computation_node) + new_computation_node.meta.update(computation_node.meta) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4e57a1d505f6c83ed1e1883242a9e7fcdb0e8a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/ddp_fusion.py @@ -0,0 +1,599 @@ +# Owner(s): ["oncall: distributed"] +import collections +import inspect +import logging +import math +import operator +from dataclasses import dataclass +from functools import partial +from typing import ( + Any, + Callable, + cast, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + Union, +) + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from .. import config +from ..fx_utils import get_fake_args_kwargs +from ..virtualized import V + + +aten = torch.ops.aten +logger: logging.Logger = logging.getLogger("comm_fusion") + + +def move_block_after(block: List[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.append(node) + target_node = node + + +def move_block_before(block: List[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.prepend(node) + target_node = node + + +def call_function( + graph: fx.Graph, + target: Union[str, Callable[..., Any]], + args: Optional[Tuple[fx.node.Argument, ...]] = None, + kwargs: Optional[Dict[str, fx.node.Argument]] = None, +) -> fx.Node: + # We accept target as a str to avoid typing error as the type of + # a node.target is Union[str, Callable[..., Any]]. + # This also allows us to avoid writing check for every call. + if isinstance(target, str): + raise RuntimeError(f"Call function should not get a str target {target=}") + node = graph.call_function(target, args, kwargs) + _, args, kwargs = get_fake_args_kwargs(node) + with V.fake_mode: + node.meta["val"] = target(*args, **kwargs) + # node.meta["val"] may be a container. So we use tree_map here + # to recursively extract the tensor metadata. + node.meta["tensor_meta"] = tree_map( + _extract_tensor_metadata, (node.meta["val"],) + )[0] + return node + + +@dataclass(unsafe_hash=True) +class CommBlock: + shape: Union[torch.Size, List[torch.Size]] + node_list: List[fx.Node] + inputs: List[fx.Node] + wait_nodes: List[fx.Node] + comm_node: fx.Node + outputs: Set[fx.Node] + + +def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]: + """ + Given a collective node (e.g., allreduce), find out all the nodes belong to + this communcation. + + Args: + comm_node(fx.Node): The target communication/collective node. + Returns: + The CommBlock that encapsulates the related nodes (e.g., wait_node) of + the given comm_node. + """ + node_list = [] + wait_nodes = [] + inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs)) + input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)] + wait_prefixes = "wait_tensor" + # If the users of the wait node are following items, we consinder them + # to be a part of the output. + intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias") + + first_user = next(iter(comm_node.users)) + if ( + len(comm_node.users) == 1 + and first_user.target == torch.ops._c10d_functional.wait_tensor.default + ): + # Collective with only one output + node_list = [comm_node, first_user] + wait_nodes.append(first_user) + elif len(comm_node.users) > 1 and first_user.target == operator.getitem: + # Collective with only more than one output + node_list.append(comm_node) + for user in comm_node.users: + if user.target != operator.getitem: + return None + if len(user.users) != 1: + return None + wait_node = next(iter(user.users)) + if wait_node.target != torch.ops._c10d_functional.wait_tensor.default: + return None + wait_nodes.append(wait_node) + node_list.append(user) + node_list.extend(wait_nodes) + else: + return None + + # Identify all the outputs of this collective block. + outputs: Set[fx.Node] = set() + nodes = collections.deque(wait_nodes) + while nodes: + node = nodes.popleft() + for user in node.users: + if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs): + nodes.append(user) + node_list.append(user) + else: + outputs.add(node) + break + + tensor_meta = input_nodes[0].meta["tensor_meta"] + shape: Union[torch.Size, List[torch.Size]] + if isinstance(tensor_meta, TensorMetadata): + shape = tensor_meta.shape + elif isinstance(tensor_meta, (list, tuple)): + shape = [tm.shape for tm in tensor_meta] + else: + logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta)) + return None + + return CommBlock( + shape=shape, + node_list=node_list, + wait_nodes=wait_nodes, + comm_node=comm_node, + inputs=input_nodes, + outputs=outputs, + ) + + +def get_all_comm_blocks( + graph: fx.Graph, + comm_ops: Tuple[torch._ops.OpOverload, ...], + comm_filter: Optional[Callable[..., bool]] = None, +) -> List[CommBlock]: + if comm_filter is None: + + def always_true(comm_block: CommBlock) -> bool: + return True + + comm_filter = always_true + + blocks = [] + for node in graph.nodes: + if node.target not in comm_ops: + continue + comm_block = get_comm_block(node) + if comm_block is not None and comm_filter(comm_block): + blocks.append(comm_block) + return blocks + + +def _fuse_allreduce_by_concat( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: List[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce using concat.""" + # Flatten all the inputs to the all_reduce nodes. + with graph.inserting_after(last_input_node): + cat_inputs = [] + for input_node in all_input_nodes: + assert isinstance(input_node.args[0], fx.Node) + input_node = input_node.args[0] + cat_inputs.append( + call_function(graph, aten.flatten.using_ints, (input_node,)) + ) + + # Concat all the flattened nodes. + with graph.inserting_after(cat_inputs[0]): + cat_node = call_function(graph, aten.cat, (cat_inputs,)) + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_after(cat_node): + div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0])) + + # Create a new Comm/all_reduce node. + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + with graph.inserting_after(div_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = div_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs) + + # Create a new Wait node. + with graph.inserting_after(fused_comm_node): + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + flatten_args[0] = fused_comm_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs) + + # Move the fused all_reduce and its args to right after the input node + nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node] + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape, + node_list=[fused_comm_node, fused_wait_node], + wait_nodes=[fused_wait_node], + comm_node=fused_comm_node, + inputs=[div_node], + outputs={fused_wait_node}, + ) + + +def _fuse_with_coalesced_op( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: List[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce by coalesced.""" + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + dividends = [div.args[0] for div in all_input_nodes] + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_before(last_input_node): + last_input_node = call_function( + graph, aten._foreach_div.Scalar, (dividends, divisors[0]) + ) + input_node = last_input_node + + # Create a new Comm/all_reduce_coalesced node. + with graph.inserting_after(last_comm_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = input_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function( + graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs + ) + + # Create a new wait node. + getitem_nodes = [] + wait_nodes = [] + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + for idx in range(len(all_input_nodes)): + with graph.inserting_after(fused_comm_node): + gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx)) + getitem_nodes.append(gi_node) + flatten_args[0] = gi_node + args, kwargs = tree_unflatten(flatten_args, spec) + with graph.inserting_after(gi_node): + wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs)) + + # Move the new all_reduce_coalesced and its args to right after the input node + nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=[ + tm.shape + for tm in cast( + List[TensorMetadata], fused_comm_node.meta.get("tensor_meta") + ) + ], + node_list=[fused_comm_node] + getitem_nodes + wait_nodes, + wait_nodes=wait_nodes, + comm_node=fused_comm_node, + inputs=[input_node], + outputs=set(wait_nodes), + ) + + +def _scatter_fused_allreduce_waits( + graph: fx.Graph, + fused_comm_block: CommBlock, + orig_comm_blocks: List[CommBlock], + node_indices: Dict[fx.Node, int], + split_and_reshape: bool = True, +) -> None: + """ + Scatters the result of the fused communication node to the original users. + If the fused method is concat splitting the output and reshape will be inserted, + before inserting getitem. Otherwise getitem will be used as the users of the + wait node. + """ + + # Before we mass up the order, we need to get the index of the last wait node + # in orig_comm_blocks. This index will be later used to determinee what users + # nodes need to be move to maintain a correct topological sort order. + last_wait_node_idx = 0 + for node in graph.nodes: + last_wait_node_idx = max( + node_indices.get(node, last_wait_node_idx), last_wait_node_idx + ) + if node == orig_comm_blocks[-1].wait_nodes[0]: + break + + if split_and_reshape: + fused_wait_node = fused_comm_block.wait_nodes[0] + with graph.inserting_after(fused_wait_node): + split_node = call_function( + graph, + aten.split, + ( + fused_wait_node, + [math.prod(cast(List[int], cb.shape)) for cb in orig_comm_blocks], + ), + ) + with graph.inserting_after(split_node): + fused_outputs = [] + for idx, comm_block in enumerate(orig_comm_blocks): + split_idx_node = call_function( + graph, operator.getitem, (split_node, idx) + ) + with graph.inserting_after(split_idx_node): + fused_outputs.append( + call_function( + graph, aten.reshape, (split_idx_node, comm_block.shape) + ) + ) + else: + fused_outputs = fused_comm_block.wait_nodes + + # Scatter the fused outputs. + incorrect_order_nodes = [] + for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs): + # Some descendant users of the orig_comm_blocks may be scheduled before + # the fused all_reduce. For example, the user nodes of the very first + # all_reduce may be scheduled before the second all_reduce. Since the + # fused all_reduce is inserted right after the last all_reudce, the + # order can be wrong. + # `incorrect_order_nodes` records these nodes. + + orig_wait = comm_block.wait_nodes[0] + nodes = collections.deque(list(orig_wait.users)) + while nodes: + user_node = nodes.popleft() + if not isinstance(user_node, fx.Node): + continue + if node_indices[user_node] < last_wait_node_idx: + incorrect_order_nodes.append(user_node) + nodes.extend(list(user_node.users)) + + orig_wait.replace_all_uses_with(fused_output) + + last_fused_result = fused_outputs[0] + fused_outputs_set = set(fused_outputs) + for node in graph.nodes: + if node in fused_outputs_set: + last_fused_result = node + + # Move the incorrect_order_nodes to right after the last fused_result. + incorrect_order_nodes = sorted( + incorrect_order_nodes, key=lambda node: node_indices[node] + ) + move_block_after(incorrect_order_nodes, last_fused_result) + + +def _fuse_allreduce( + graph: fx.Graph, + comm_blocks: List[CommBlock], + node_indices: Dict[fx.Node, int], + use_concat: bool, +) -> CommBlock: + """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.""" + + if len(comm_blocks) == 1: + return comm_blocks[0] + + # Find the last input node of all the CommBlocks. This node will be served + # as the inserting point of the new collective op. + last_input_node = comm_blocks[0].inputs[0] + last_input_index = -1 + all_input_nodes = [] + for comm_block in comm_blocks: + input_node = comm_block.inputs[0] + all_input_nodes.append(input_node) + index = node_indices[input_node] + if index >= last_input_index: + assert index != last_input_index + last_input_node = input_node + last_input_index = index + + if use_concat: + fused_comm_block = _fuse_allreduce_by_concat( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + else: + fused_comm_block = _fuse_with_coalesced_op( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + + _scatter_fused_allreduce_waits( + graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat + ) + + for comm_block in comm_blocks: + for wait in comm_block.wait_nodes: + graph.erase_node(wait) + graph.erase_node(comm_block.comm_node) + graph.eliminate_dead_code() + + return fused_comm_block + + +def _bucket_size_fusion( + graph: fx.Graph, comm_blocks: List[CommBlock], bucket_size_mb: int +) -> Generator[List[CommBlock], None, None]: + MB = 1024**2 + bucket_size = 1 * MB + bucket_cap_size = bucket_size_mb * MB + curr_size = 0 + curr_blocks = [] + + count = 0 + fuse_count = 0 + for i, block in enumerate(comm_blocks): + curr_blocks.append(block) + itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize + curr_size += cast(torch.Size, block.shape).numel() * itemsize + count += 1 + if curr_size < bucket_size and i != len(comm_blocks) - 1: + continue + + fuse_count += 1 + if torch.distributed.get_rank() == 0: + logger.info( + "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d", + fuse_count, + count, + curr_size, + bucket_size, + ) + + # Set the debug counters + counters["inductor"]["ddp_buckets"] = fuse_count + yield curr_blocks + + bucket_size = bucket_cap_size + curr_blocks = [] + curr_size = 0 + count = 0 + + +def _fuse_ddp_communication( + graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any] +) -> None: + for output in reversed(graph.nodes): + if output.op == "output": + break + + def ddp_reducer_filter(block: CommBlock) -> bool: + if ( + not isinstance(block.comm_node.args[0], fx.Node) + or block.comm_node.args[0].target != aten.div.Tensor + ): + return False + + if len(block.wait_nodes[0].users) != 1: + # gradient/wait node should only be used by one user + return False + + # Two cases: + # 1. gradient/wait node should be directly used by the output + # if gradient is None before bwd. + # 2. gradient/wait node should be directly used by copy_. + if ( + output not in block.wait_nodes[0].users + and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default + ): + return False + + return True + + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter) + node_indices = {node: i for i, node in enumerate(graph.nodes)} + + for block in algorithm_fn(graph, comm_blocks): + fusion_fn(graph, block, node_indices) + + +def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=False), + ) + + +def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=True), + ) + + +def schedule_comm_wait(graph: fx.Graph) -> None: + """ + Delay the execution of wait tensors of allreduce until its first user. + + This algorithm considers the intermediate users, like split, getitem, + of the wait node and schedule those intermediate users as well. + This will result in a better overlapping result. + """ + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_reduce_coalesced.default, + torch.ops._c10d_functional.all_reduce_coalesced_.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops) + if not comm_blocks: + return + + # Find all the end users. + allreduce_users: Set[fx.Node] = set() + for allreduce in comm_blocks: + for output in allreduce.outputs: + allreduce_users.update(output.users) + + node_indices = {node: i for i, node in enumerate(graph.nodes)} + for allreduce in comm_blocks: + # Find the earliest/first user -- target_node. + assert ( + len(allreduce.outputs) >= 1 + ), f"Found a allreduce that has zero outputs/users -- {allreduce}." + # Initialize the target node to avoid typing issues. + target_node = next(iter(next(iter(allreduce.outputs)).users)) + target_node_index = 2**31 + for user in (user for output in allreduce.outputs for user in output.users): + index = node_indices[user] + if index < target_node_index: + target_node = user + target_node_index = index + + # Move wait nodes and all the subsequent nodes in the comm_block to + # before the first user -- target_node. + wait_idx = -1 + for wait_idx, node in enumerate(allreduce.node_list): + if node == allreduce.wait_nodes[0]: + break + assert wait_idx >= 0 + move_block_before(allreduce.node_list[wait_idx:], target_node) + + +def fuse_ddp_communication( + graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int +) -> None: + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, + f"fuse_ddp_communication_pass_{i}", + config.trace.log_url_for_graph_xform, + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in { + v.name for v in inspect.signature(func).parameters.values() + }: + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..d6777987c0866258976d4e73205ba49084f92340 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,153 @@ +# mypy: allow-untyped-defs +import logging +from typing import List + +import torch +from torch import Tensor +from torch._dynamo.utils import counters + +from .. import config +from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern +from .split_cat import construct_pattern_matcher_pass + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + +min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION +max_other_dimention_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +if "decompose_mm_pass" in config.post_grad_fusion_options: + min_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) + max_other_dimention_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def realize_inputs(inputs: List[torch.fx.Node]): + for inp in inputs: + if isinstance(inp, torch.fx.node.Node): + inp.meta["inductor_realize_to_strides"] = True + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if not check_device(mat1, mat2): + return False + else: + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if mat1.shape[0] < min_first_dimension_decomposition: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + if (mat1.shape[1] < max_other_dimention_decomposition) + ( + mat1.shape[2] < max_other_dimention_decomposition + ) + (mat2.shape[2] < max_other_dimention_decomposition) < 2: + return False + return True + + +def should_decompose_mm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + check_device(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[0] >= min_first_dimension_decomposition + and mat2.shape[0] < max_other_dimention_decomposition + and mat2.shape[1] < max_other_dimention_decomposition + ) + + +def is_node_meta_valid(node: torch.fx.Node): + return "val" in node.meta + + +def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to( + mat1.dtype + ) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return ( + torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1 + ) + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + realize_inputs([mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_mm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py new file mode 100644 index 0000000000000000000000000000000000000000..f24b68b6c8f68fac86ee332ea9e610fdcf30badd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Union + +import torch +from torch import SymBool, SymFloat, SymInt +from torch.types import py_sym_types + + +@dataclass +class _SymExprHash: + """ + Hash for a py_sym_types that will use the underlying sympy expression + """ + + sym_obj: Union[SymInt, SymFloat, SymBool] + + def __hash__(self) -> int: + return hash((type(self.sym_obj), self.sym_obj.node.expr)) + + def __eq__(self, value) -> bool: + if not isinstance(value, _SymExprHash): + return False + return self.sym_obj.node.expr == value.sym_obj.node.expr + + +class _SymHashingDict: + """ + Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse + existing sym proxies. + + SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, + fallback to symnodes. + """ + + def __init__(self): + self.sym_hash_dict = {} + + def __setitem__(self, key, value): + self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) + + def __getitem__(self, key): + return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] + + def __contains__(self, key): + return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict + + def get(self, key, default=None): + return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) + + def _wrap_to_sym_expr_hash(self, key): + return _SymExprHash(key) if isinstance(key, py_sym_types) else key + + +def dedupe_symints(graph: torch.fx.Graph): + """ + Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. + + We only dedupe from graph inputs to avoid adding a potential dependency in the forward + from the backward. + + """ + + sym_dict = _SymHashingDict() + resolvable_from_input_symints = set() + + for node in graph.nodes: + val = node.meta.get("val", None) + if val is None or not isinstance(val, py_sym_types): + continue + + if node.op == "placeholder": + resolvable_from_input_symints.add(node) + sym_dict[val] = node + elif existing_node := sym_dict.get(val): + node.replace_all_uses_with(existing_node) + graph.erase_node(node) + elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): + sym_dict[val] = node + resolvable_from_input_symints.add(node) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4845142caab347ce9f67c958772fe7b23fddc599 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,406 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import ( + CallFunctionVarArgs, + CallModuleVarArgs, + Match, + register_graph_pattern, +) +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. + """ + + assert bn.running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +def efficient_conv_bn_eval_decomposed( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv: torch._ops.OpOverload, + conv_weight, + conv_bias, + x, + conv_remainging_args, +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + """ + assert bn_running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv_weight + if conv_bias is not None: + bias_on_the_fly = conv_bias + else: + bias_on_the_fly = torch.zeros_like(bn_running_var) + + if bn_weight is not None: + bn_weight = bn_weight + else: + bn_weight = torch.ones_like(bn_running_var) + + if bn_bias is not None: + bn_bias = bn_bias + else: + bn_bias = torch.zeros_like(bn_running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv_weight.ndim - 1) + if "conv_transpose" in conv.__str__(): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn_running_mean + ) + + input = x + return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.ops.aten.batch_norm.default, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 9 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-4]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_weight = bn_node.args[1] + bn_bias = bn_node.args[2] + bn_running_mean = bn_node.args[3] + bn_running_var = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): # type: ignore[arg-type] + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. + conv_get_node = graph.create_node( + op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr] + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..572fcc9e9299d2cf4d0b1fa901568f5e86bf4cad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. + if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. + if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=binary_folding_pass, + ) + + +@functools.lru_cache(None) +def addmm_patterns_init(): + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + + def check_concat_weights(match): + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + equal_shape_inputs = [weight_inputs] + + if "b1" in match.kwargs: + bias_inputs = ["b1", "b2"] + if "b3" in match.kwargs: + bias_inputs.append("b3") + + equal_shape_inputs.append(bias_inputs) + + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + + if not all( + inp.op == "get_attr" + and inp.meta["val"].shape == inps[0].meta["val"].shape + for inp in inps + ): + return False + + return True + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + matmul_fuse_pattern, + matmul_replacement, + [val(), val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + matmul_fuse_pattern_two, + matmul_replacement_two, + [val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + addmm_fuse_pattern_second, + addmm_fuse_replacement_second, + [val() for _ in range(7)], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + graph.erase_node(node) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3b681230e06d0c0f7089b7740cddefc0e50634 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/fuse_attention.py @@ -0,0 +1,909 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import logging +import math + +import torch +from torch.nn.attention import sdpa_kernel, SDPBackend + +from ..._dynamo.utils import counters +from ..pattern_matcher import ( + filter_nodes, + fwd_only, + gen_register_replacement, + joint_fwd_bwd, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +if torch.version.hip: + + def _scaled_dot_product_attention(*args, **kwargs): + with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]): + return aten.scaled_dot_product_attention(*args, **kwargs) + +else: + _scaled_dot_product_attention = aten.scaled_dot_product_attention + + +def _sfdp_pattern_1(query, key, value, inv_scale): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_1(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_2(query, key, value, scale_factor): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(scale_factor) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_2(query, key, value, scale_factor): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale_factor) + .softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_5(query, key, value, attn_mask): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + # attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ value + + +def _sfdp_replacement_5(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_7(query, key, value, dropout_p): + # in real workloads inputs to matmul are permuted + # causing matmul to expand to a series of expand and clone calls + # we want the same to happen during pattern tracing + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_7(query, key, value, dropout_p): + # sdpa prefers inputs in permuted format + # it makes a copy to put them in this format + # if they aren't already + # to make replacement efficient ensure that inputs to sdpa + # are in required order + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_8(query, key, value): + # no dropout version of pattern 7 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_8(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_9(query, key, value, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_9(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_10(query, key, value): + # no dropout version of 9 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_10(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_11(query, key, value, inv_scale): + # Mainly for huggingface models + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v) + + +def _sfdp_replacement_11(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + +def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_13(query, key, value, dropout_p): + attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p) + return torch.bmm(attn_weight, value) + + +def _sfdp_replacement_13(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + dropout_p=dropout_p, + scale=1.0, + ).squeeze(0) + + +def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): + # for BertLarge + # Permutations are needed to create clones in graph. + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask) + .softmax(dim=-1) + .matmul(v) + ) + + +def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): + # for DistilBert + # Permutations are needed to create clones in graph. + # Ref: https://github.com/pytorch/pytorch/issues/119911 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v + + +def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): + # for BertLarge with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + torch.nn.functional.dropout( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax( + dim=-1 + ), + dropout_p, + ) + .to(dtype=query.dtype) + .matmul(v) + ) + + +def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p): + # for DistilBert with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): + # for hf_GPT2 with dropout (introduces clone node) for inference + # it also returns permuted key & value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + return ( + ( + torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul( + value + ) + ), + key, + value, + ) + + +def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + permuted_key = key.transpose(1, 2) + permuted_value = value.transpose(1, 2) + return ( + _scaled_dot_product_attention( + query.transpose(1, 2), + permuted_key, + permuted_value, + attn_mask=causal_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ), + permuted_key, + permuted_value, + ) + + +def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): + # for token-classification+gpt2 / text-generation+gpt2 + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + attn_weights = attn_weights + attn_mask + attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) + return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value) + + +def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = torch.where(causal_mask, attn_mask, fill_value) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ) + + +def _sfdp_params_check(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_mask_node = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + ): + return False + return True + + +def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? + if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check(match) + + return fn + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def _get_sfdp_patterns(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + # attn_mask + b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) + # inv_scale + c_inp = functools.partial(torch.tensor, 2.0, device=device) + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + d = {"dropout_p": 0.113377} + + # we could also generate all these patterns in 3d.. TODO + g_3d_inp = functools.partial( + torch.empty, (1024, 128, 128), device=device, requires_grad=True + ) + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. + g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + + # softmax will generate a dtype conversion on inputs if they are in half, + # but will not in float, so we generate a pattern for both + for dtype in [torch.float, torch.half]: + g = functools.partial(g_inp, dtype=dtype) + b = functools.partial(b_inp, dtype=dtype) + b_float = functools.partial(b_inp, dtype=torch.float) + b_bool = functools.partial(b_inp, dtype=torch.bool) + m = functools.partial(m_inp, dtype=dtype) + m_float = functools.partial(m_inp, dtype=torch.float) + m_bool = functools.partial(m_inp, dtype=torch.bool) + c = functools.partial(c_inp, dtype=dtype) + g_3d = functools.partial(g_3d_inp, dtype=dtype) + g_bs1 = functools.partial(g_bs1_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float) + m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool) + + candidates = [ + ( + _sfdp_pattern_1, + _sfdp_replacement_1, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_2, + _sfdp_replacement_2, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_3, + _sfdp_replacement_3, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_4, + _sfdp_replacement_4, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_5, + _sfdp_replacement_5, + [g(), g(), g(), b()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_6, + _sfdp_replacement_6, + [g(), g(), g(), b()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_7, + _sfdp_replacement_7, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_8, + _sfdp_replacement_8, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_9, + _sfdp_replacement_9, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_10, + _sfdp_replacement_10, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_11, + _sfdp_replacement_11, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_12, + _sfdp_replacement_12, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_13, + _sfdp_replacement_13, + [g_3d(), g_3d(), g_3d()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_14, + _sfdp_replacement_14, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_15, + _sfdp_replacement_15, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_17, + _sfdp_replacement_17, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g(), g(), g(), m_bool()], + d, + # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed + _sfdp_extra_check(disable_cuda=True), + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()], + d, + # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed + _sfdp_extra_check(disable_cuda=True), + ), + ( + _sfdp_pattern_19, + _sfdp_replacement_19, + [g(), g(), g(), b_bool(), b_float()], + d, + _sfdp_params_check, + ), + ] + mask_fp32_patterns = ["pattern_16"] + if dtype == torch.half: + # Add inputs of bf16 q/k/v and fp32 mask, for models like albert. + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + + for pattern, replacement, args, workaround, extra_check in candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + if ( + any(p in name for p in mask_fp32_patterns) + and args[3].dtype == torch.float32 + ): + name += "_mask_fp32" + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + } + + if workaround: + assert len(workaround) == 1 and "dropout_p" in workaround + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, + } + + +@functools.lru_cache(None) +def _sfdp_init(): + for key, register_replacement_kwargs in _get_sfdp_patterns(): + gen_register_replacement(key, **register_replacement_kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5311c9789fa59ed3c2f993b14083b2b9e698ffe2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py @@ -0,0 +1,1317 @@ +# mypy: allow-untyped-defs +import collections +import logging +import operator +from collections import OrderedDict +from typing import ( + Any, + DefaultDict, + Deque, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, +) + +import torch +from torch._dynamo.utils import counters, optimus_scuba_log +from torch._utils_internal import upload_graph +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from .. import config +from ..pattern_matcher import ( + CallFunctionVarArgs, + get_arg_value, + stable_topological_sort, +) + + +try: + # importing this will register fbgemm lowerings for inductor + import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 + + has_fbgemm = True +except Exception: + has_fbgemm = False + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + +MIN_FUSE_SET_SIZE = 5 +MAX_FUSE_SET_SIZE = 300 +MAX_FUSE_SEARCH_DEPTH = 5 +# The maximum tensor size that can go into the fusion group +MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False + +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = {operator.getitem} + + +default_graph_search_options = { + "min_fuse_set_size": MIN_FUSE_SET_SIZE, + "max_fuse_set_size": MAX_FUSE_SET_SIZE, + "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, + "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, +} + +graph_search_options = default_graph_search_options + + +def update_stack_example_value(node, metadata, dim=0, op=torch.stack): + """ + Update the example value of the node in the graph to enable followup split cat opt. + """ + if node is not None and hasattr(node, "meta"): + if op == torch.stack: + example_value = torch.stack(metadata, dim=dim) + elif op == torch.unbind: + example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] + else: + return + node.meta["example_value"] = example_value + + +def update_pointwise_example_value(pointwise_node, input, other, op): + """ + Update the example value of the add node in the graph to enable followup split cat opt. + """ + if pointwise_node is not None and hasattr(pointwise_node, "meta"): + if op == torch.add: + example_value = torch.add(input, other) + elif op == torch.mul: + example_value = torch.mul(input, other) + else: + return + pointwise_node.meta["example_value"] = example_value + + +class GroupBatchFusionBase: + def __init__(self, **kwargs) -> None: + self.graph_search_options = kwargs.pop( + "graph_search_options", default_graph_search_options + ) + + def match(self, node): + raise NotImplementedError("match called on base") + + def fuse(self, graph, subset): + raise NotImplementedError("fuse called on base") + + +PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {} +POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {} + + +def register_fusion(name: str, pre_grad=True): + def decorator(fusion_cls: GroupBatchFusionBase): + if pre_grad: + PRE_GRAD_FUSIONS[name] = fusion_cls + else: + POST_GRAD_FUSIONS[name] = fusion_cls + return fusion_cls + + return decorator + + +def list_group_batch_fusions(pre_grad=True) -> List[str]: + if pre_grad: + return list(PRE_GRAD_FUSIONS.keys()) + else: + return list(POST_GRAD_FUSIONS.keys()) + + +def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any: + unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] + for input_tensor in input_tensors: + unsqueezed_input = graph.call_function( + aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} + ) + unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) + stacked_inputs = graph.call_function( + aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} + ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] + return stacked_inputs + + +class GroupFusion(GroupBatchFusionBase): + """ + Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. + """ + + +class BatchFusion(GroupBatchFusionBase): + """ + Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. + """ + + +class BatchPointwiseOpsFusionFactory(BatchFusion): + def __init__(self, op, **kwargs) -> None: + super().__init__(**kwargs) + self.op = op + + +@register_fusion("batch_linear_post_grad", pre_grad=False) +class PostGradBatchLinearFusion(BatchFusion): + """ + Fuse ops in a batch way in post grad (aten level). + """ + + def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type + return ( + node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value] + ) + + def _is_input_2d(self, input: torch.fx.Node) -> bool: + input_shapes = input.meta["val"].shape + return ( + len(input_shapes) == 2 + and isinstance(input_shapes[0], int) + and isinstance(input_shapes[1], int) + ) + + def match( + self, node: torch.fx.Node + ) -> Optional[Tuple[str, int, int, int, bool, str]]: + if CallFunctionVarArgs(aten.mm).match(node): + input_m, weight_m = node.args + bias_m = None + + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias_m, input_m, weight_m = node.args + else: + return None + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + # only handle the cases where inputs are 2D tensors + if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] + return None + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) + return batch_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] + + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + elif CallFunctionVarArgs(aten.mm.default).match(node): + input, weight = node.args + bias = None + batch_nodes.append(node) + batch_inputs.append(input) # type: ignore[possibly-undefined] + batch_weights.append(weight) # type: ignore[possibly-undefined] + batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) + + with graph.inserting_before(subset[-1]): + fused_inputs = decompose_stack(graph, batch_inputs) + fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) + fused_bmm = graph.call_function( + aten.bmm, + args=(fused_inputs, fused_weights), + ) + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) + for i, original_mm in enumerate(batch_nodes): + has_bias = False + with graph.inserting_after(fused_bmm): + new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) + if batch_biases[i]: + has_bias = True + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] + new_bias_add = graph.call_function( + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) + new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] + original_mm.replace_all_uses_with(new_mm_cont) + new_mm_cont.meta.update(original_mm.meta) + graph.erase_node(original_mm) + counters["inductor"]["batch_linear_post_grad"] += 1 + + +@register_fusion("group_linear", pre_grad=False) +class GroupLinearFusion(GroupFusion): + def _addmm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] + return ( + node.kwargs.get("beta", 1.0) == 1.0 + and node.kwargs.get("alpha", 1.0) == 1.0 + and len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def _mm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + return ( + len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]: + if CallFunctionVarArgs(aten.mm.default).match( + node + ) and self._mm_node_can_be_fused(node): + group_key = ("group_linear", True) + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias = node.args[0] + group_key = ("group_linear", bias is None) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + group_inputs = [] + group_weights = [] + group_biases = [] + group_nodes = [] + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + else: + assert CallFunctionVarArgs(aten.mm.default).match(node) + input, weight = node.args + bias = None + + group_nodes.append(node) + group_inputs.append(input) + group_weights.append(weight) + group_biases.append(bias) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + + with graph.inserting_before(subset[0]): + fused_mm = graph.call_function( + torch.ops.fbgemm.gmm.default, + args=(group_inputs, group_weights, group_biases), + kwargs={"smart_fused": True}, + ) + + for i, original_mm in enumerate(group_nodes): + with graph.inserting_after(fused_mm): + new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) + original_mm.replace_all_uses_with(new_mm) + new_mm.meta.update(original_mm.meta) + graph.erase_node(original_mm) + counters["inductor"]["group_linear"] += 1 + + +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise math operator (e.g., add, mul) in post grad pass. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def _pointwise_node_can_be_fused(self, node: torch.fx.Node): + # note: we only consider the case where the inputs are tensors + # for mixed precision training, we need to make sure the inputs + # of the aten.cat when do the stack should be the same dtype + # otherwise, the output of the aten.cat may be not the same as + # its inputs, and cause dtype not same error in mm or addmm + input, other = node.args + return ( + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] + if hasattr(input, "meta") + and hasattr(other, "meta") + and "val" in input.meta # type: ignore[union-attr] + and "val" in other.meta # type: ignore[union-attr] + else False + ) + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(self.op).match( + node + ) and self._pointwise_node_can_be_fused(node): + alpha = node.kwargs.get("alpha", 1.0) + rounding_mode = node.kwargs.get("rounding_mode", None) + input, other = node.args + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target == aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(shape), + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] + str(alpha), + str(rounding_mode), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_inputs, batch_others = [], [] + alpha = subset[0].kwargs.get("alpha", 1.0) + batch_inputs_meta, batch_others_meta = [], [] + + for node in subset: + input, other = node.args + batch_inputs.append(input) + batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] + + with graph.inserting_before(subset[0]): + stack_inputs = decompose_stack(graph, batch_inputs) + stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + [other["val"] for other in batch_others_meta] + ) + + batch_op = graph.call_function( + self.op, + args=(stack_inputs, stack_others), + kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, + ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) + for i, original_add in enumerate(subset): + with graph.inserting_after(batch_op): + new_add = graph.call_function( + torch.ops.aten.select, args=((batch_op, 0, i)) + ) + original_add.replace_all_uses_with(new_add) + new_add.meta.update(original_add.meta) + graph.erase_node(original_add) + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +@register_fusion("batch_linear_lhs") +class BatchLinearLHSFusion(BatchFusion): + """ + Batch linear left-hand side fusion. This pass tries to fuse the following patterns: + + torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) + -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) + + We have a separate pass to eliminate contiguous transpose in a generic way. + """ + + def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]: + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + bias = get_arg_value(node, 2, "bias") + group_key = ("batch_linear_lhs", bias is None, input) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_input = None + batch_weights, batch_weights_meta = [], [] + batch_biases, batch_biases_meta = [], [] + split_sections = [] + for node in subset: + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + batch_nodes.append(node) + if batch_input is None: + batch_input = input + else: + assert batch_input is input + batch_weights.append(weight) + batch_weights_meta.append(weight.meta["example_value"]) + if bias: + batch_biases.append(bias) + batch_biases_meta.append(bias.meta["example_value"]) + split_sections.append(weight.meta["example_value"].shape[0]) + + with graph.inserting_before(subset[0]): + cat_weights = graph.call_function( + torch.cat, args=(batch_weights,), kwargs={"dim": 0} + ) + cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0) + transposed_weights = graph.call_function( + torch.transpose, args=(cat_weights, 0, 1) + ) + transposed_weights.meta["example_value"] = torch.transpose( + cat_weights.meta["example_value"], 0, 1 + ) + if len(batch_biases) > 0: + cat_biases = graph.call_function( + torch.cat, args=(batch_biases,), kwargs={"dim": 0} + ) + cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0) + fused_lhs = graph.call_function( + torch.addmm, + args=(cat_biases, batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.addmm( + cat_biases.meta["example_value"], + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + else: + fused_lhs = graph.call_function( + torch.mm, + args=(batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.mm( + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + fused_lhs_list = graph.call_function( + torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} + ) + + for i, node in enumerate(batch_nodes): + with graph.inserting_after(fused_lhs_list): + new_node = graph.call_function( + operator.getitem, args=(fused_lhs_list, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["batch_linear_lhs"] += 1 + + +def is_node_meta_valid(node: Optional[torch.fx.Node]): + return node is None or "example_value" in node.meta or "val" in node.meta + + +# Poor person's check for if a node in the graph mutates its input. +# (the graph is torch IR, so we will see torch fns and python operators) +def _is_mutable_node(tgt): + if str(tgt).endswith("_"): + # e.g. torch.mul_, torch.Tensor.mul_ + return True + if ( + hasattr(tgt, "__module__") + and tgt.__module__ == "_operator" + and tgt.__name__.startswith("i") + ): + # e.g. operator.iand, operator.imul + return True + return False + + +def is_linear_node_can_be_fused(node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + return ( + is_node_meta_valid(node) + and is_node_meta_valid(input) + and is_node_meta_valid(weight) + and len(input.meta["example_value"].shape) == 2 + and len(weight.meta["example_value"].shape) == 2 + # the mm -> bmm transform adds an unbind() op, + # which is not safe for autograd when the output of the mm is mutated. + # don't pattern match if any users of the mm mutate the input. + and not any(_is_mutable_node(user.target) for user in node.users) + ) + + +@register_fusion("batch_linear") +class PreGradBatchLinearFusion(BatchFusion): + """ + Batch linear fusion in pre grad pass. + Fuse linear with same size with torch.baddmm + """ + + def _getitem_args(self, getitem_node: torch.fx.Node): + if getitem_node.target != operator.__getitem__ or ( + getitem_node.op != "call_function" + ): + return None + return getitem_node.args[0] + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + "batch_linear", + self._getitem_args(input), + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape), + bias is None, + str(users), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_inputs_metadata = [] + batch_weights_metadata = [] + batch_biases_metadata = [] + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + weight = get_arg_value(node, 1, "weight") + batch_weights.append(weight) + batch_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 2, "bias") + batch_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + batch_biases_metadata.append(bias.meta["example_value"]) + + with graph.inserting_before(subset[0]): + stack_inputs = graph.call_function( + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + stack_weights = graph.call_function( + torch.stack, args=(batch_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weights, batch_weights_metadata) + transpose_weight = graph.call_function( + torch.transpose, args=(stack_weights, 1, 2) + ) + transpose_weight.meta["example_value"] = torch.transpose( + stack_weights.meta["example_value"], 1, 2 + ) + if all(bias is None for bias in batch_biases): + bmm = graph.call_function( + torch.bmm, + args=(stack_inputs, transpose_weight), + ) + bmm.meta["example_value"] = torch.bmm( + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + else: + stack_biases = graph.call_function( + torch.stack, args=(batch_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_biases, batch_biases_metadata) + unsqueeze_biases = graph.call_function( + torch.unsqueeze, args=(stack_biases, 1) + ) + unsqueeze_biases.meta["example_value"] = torch.unsqueeze( + stack_biases.meta["example_value"], 1 + ) + bmm = graph.call_function( + torch.baddbmm, + args=(unsqueeze_biases, stack_inputs, transpose_weight), + ) + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None + + bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + for i, linear in enumerate(batch_nodes): + with graph.inserting_after(bmm): + getitem = graph.call_function(operator.getitem, args=(bmm, i)) + linear.replace_all_uses_with(getitem) + getitem.meta.update(linear.meta) + graph.erase_node(linear) + counters["inductor"]["batch_linear"] += 1 + + +@register_fusion("batch_layernorm") +class BatchLayernormFusion(BatchFusion): + """ + Batch layer norm fusion in pre grad pass + """ + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 2, "weight") + bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + ( + "batch_layernorm", + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape) + if weight is not None + else "", + str(bias.meta["example_value"].shape) if bias is not None else "", + str(get_arg_value(node, 1, "normalized_shape")), + str(get_arg_value(node, 4, "eps")), + str(users), + ) + if "example_value" in input.meta + and is_node_meta_valid(weight) + and is_node_meta_valid(bias) + else None + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + group_inputs = [] + group_shapes = [] + group_weights = [] + group_biases = [] + group_epss = [] + group_nodes = [] + group_inputs_metadata = [] + group_biases_metadata = [] + group_weights_metadata = [] + for node in subset: + group_nodes.append(node) + input = get_arg_value(node, 0, "input") + group_inputs.append(input) + group_inputs_metadata.append(input.meta["example_value"]) + group_shapes.append(get_arg_value(node, 1, "normalized_shape")) + weight = get_arg_value(node, 2, "weight") + group_weights.append(weight) + if weight is not None and hasattr(weight, "meta"): + group_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 3, "bias") + group_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + group_biases_metadata.append(bias.meta["example_value"]) + eps = get_arg_value(node, 4, "eps") + if eps is None: + eps = 1e-5 + group_epss.append(eps) + stack_dim = -1 - len(group_shapes[-1]) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + if all(weight is None for weight in group_weights): + group_weights = None # type: ignore[assignment] + assert all( + eps == group_epss[0] for eps in group_epss + ), "all epsilon values must be equal" + + with graph.inserting_before(subset[0]): + stack_input = graph.call_function( + torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} + ) + update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) + if group_weights is not None: + stack_weight = graph.call_function( + torch.stack, args=(group_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weight, group_weights_metadata) + else: + stack_weight = None + if group_biases is not None: + stack_bias = graph.call_function( + torch.stack, args=(group_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_bias, group_biases_metadata) + else: + stack_bias = None + + batch_layer_norm = graph.call_function( + torch.nn.functional.layer_norm, + args=(stack_input, group_shapes[-1]), + kwargs={"eps": group_epss[-1]}, + ) + batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] + + if group_weights is not None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + elif group_weights is not None and group_biases is None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + elif group_weights is None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + + batch_layer_norm_unbind = graph.call_function( + torch.unbind, + args=(batch_layer_norm,), + kwargs={"dim": stack_dim}, + ) + update_stack_example_value( + batch_layer_norm_unbind, + batch_layer_norm.meta["example_value"], + op=torch.unbind, + dim=stack_dim, + ) + + for i, node in enumerate(group_nodes): + with graph.inserting_after(batch_layer_norm_unbind): + new_node = graph.call_function( + operator.getitem, args=(batch_layer_norm_unbind, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["batch_layernorm"] += 1 + + +class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. + We fuse it in random place, and the introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" + # for relu op, we also use the inplace to construct the key + group_key = ( + "batch_" + self.op.__name__.lower().split(".")[0], + str(input.meta["example_value"].shape), + str(node.kwargs.get("inplace", False)), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): + stack_inputs = graph.call_function( + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + if self.op == torch.nn.functional.relu: + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], + inplace=subset[0].kwargs.get("inplace", False), + ) + else: + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"] + ) + unbind_op = graph.call_function( + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. + The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +@register_fusion("batch_tanh") +class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.tanh, **kwargs) + + +@register_fusion("batch_sigmoid") +class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.sigmoid, **kwargs) + + +@register_fusion("batch_relu") +class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.nn.functional.relu, **kwargs) + + +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.relu.default, **kwargs) + + +@register_fusion("batch_aten_add", pre_grad=False) +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.add.Tensor, **kwargs) + + +@register_fusion("batch_aten_sub", pre_grad=False) +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sub.Tensor, **kwargs) + + +@register_fusion("batch_aten_div", pre_grad=False) +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.div.Tensor, **kwargs) + + +@register_fusion("batch_aten_mul", pre_grad=False) +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.mul.Tensor, **kwargs) + + +class _OrderedSet: + def __init__(self, param=None) -> None: + if param: + self.rep = OrderedDict(dict.fromkeys(param)) + else: + self.rep = OrderedDict() + + def __contains__(self, o) -> bool: + return o in self.rep + + def __len__(self) -> int: + return self.rep.__len__() + + def append(self, o): + self.rep[o] = None + + def __iter__(self): + return self.rep.keys().__iter__() + + +def find_independent_subset_greedy( + node_list: Iterable[torch.fx.Node], + graph_search_options: Dict[str, Any], +) -> Iterator[Iterable[torch.fx.Node]]: + """ + Yields a list of subsets of `node_list` where no element in the subset + depends on any other element in the subset. This results in a set of + independent nodes which can be fused together. + + The order of `node_list` is preserved within each subset so we can benefit + from split-cat elimination in later passes. + + During iteration it is only safe to mutate the graph by changing the nodes + that have been returned. + + graph_search_options: + - min_fuse_set_size: Minimum size of the subset to consider. Subsets below + this size will be ignored. + - max_fuse_set_size: Maximum size of the subset to consider. Subsets will + be broken to be at most this size. + """ + + # Compute all the children of `node` which are members of + # `interesting_nodes`. + def find_dependent_nodes(node, interesting_nodes): + visited_node_set: Set[torch.fx.Node] = {node} + dep_set: Set[torch.fx.Node] = set() + + work = [node] + while work: + node = work.pop() + for input_node in node.all_input_nodes: + if input_node in interesting_nodes: + dep_set.add(input_node) + + if input_node not in visited_node_set: + visited_node_set.add(input_node) + work.append(input_node) + + return dep_set + + min_fuse_set_size = graph_search_options["min_fuse_set_size"] + max_fuse_set_size = graph_search_options["max_fuse_set_size"] + + # node_list needs to be a set because we only track the nodes that are left + # in it (and we want to do the `in` on a set, not a list). But we want to + # keep the correct order. + node_list = _OrderedSet(node_list) + + cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {} + while node_list: + subset: List[torch.fx.Node] = [] + subset_deps: Set[torch.fx.Node] = set() + + next_round_node_list = _OrderedSet() + for node in node_list: + if len(subset) >= max_fuse_set_size or node in subset_deps: + next_round_node_list.append(node) + continue + + dep_set = cache.pop(node, None) + if dep_set is None: + dep_set = find_dependent_nodes(node, node_list) + + if not dep_set.intersection(subset): + subset.append(node) + subset_deps.update(dep_set) + else: + next_round_node_list.append(node) + cache[node] = dep_set + + if len(subset) >= min_fuse_set_size: + # Careful here - the caller uses the subsets to fuse nodes together + # so we need to clear any cache entry that contains one of the + # returned nodes because the dependency list could be different + # (larger) after the merge. + cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} + yield subset + + node_list = next_round_node_list + + +def get_fusion_candidates( + rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node] +) -> DefaultDict[Any, List[torch.fx.Node]]: + """ + Search fusion candidates for a specific rule using BFS starting from the root node. + We only search the subgraph within graph_search_options["max_fuse_search_depth"]. + """ + q: Deque[Tuple[int, torch.fx.Node]] = collections.deque() + + candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict( + list + ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + + visited_set: Set[torch.fx.Node] = set() + + for next_node in root_node.all_input_nodes: + q.append((1, next_node)) + visited_set.add(next_node) + + while len(q) > 0: + depth, node = q.popleft() + + if node in fused_set: + continue + + key = rule.match(node) + if key is not None: + candidate_nodes = candidate_dict[key] + if node not in candidate_nodes: + candidate_nodes.append(node) + else: + if depth < rule.graph_search_options["max_fuse_search_depth"]: + for next_node in node.all_input_nodes: + if next_node not in visited_set: + visited_set.add(next_node) + q.append((depth + 1, next_node)) + + return candidate_dict + + +def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): + stable_topological_sort(graph) # type: ignore[arg-type] + fused_set: Set[torch.fx.Node] = set() + log_to_scuba = False + + for node in reversed(graph.nodes): + candidates = get_fusion_candidates(rule, node, fused_set) + + for key, candidate_nodes in candidates.items(): + if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: + continue + + for subset in find_independent_subset_greedy( + candidate_nodes, rule.graph_search_options + ): + rule.fuse(graph, subset) + fused_set.update(subset) + log.debug( + f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 + ) + log_to_scuba = True + if log_to_scuba: + optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph) + + +def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True): + fusions: List[GroupBatchFusionBase] = [] + for name, options in config_options.items(): + # we skip all patterns from pattern_matcher passes (e.g., split_cat) + if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS: + continue + fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] + _options = graph_search_options.copy() + _options.update(options) + fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] + return fusions + + +def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): + fusions: List[GroupBatchFusionBase] = [] + # we keep all current pre grad fusions to keep + # current implementation, will remove this later + if pre_grad: + fusions += generate_fusion_from_config( + config.pre_grad_fusion_options, pre_grad=True + ) + else: + fbgemm_fusion_keys = [ + x + for x in config.post_grad_fusion_options + if config.post_grad_fusion_options[x].get("require_fbgemm", False) + ] + fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in fbgemm_fusion_keys + } + non_fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in config.post_grad_fusion_options.keys() + if fusion not in fbgemm_fusion_keys + } + fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) + if has_fbgemm: + fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) + + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + config.trace.log_url_for_graph_xform, + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..d526f1bb6a64dec65ef8e5ea012fa31ad6aa7233 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,694 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import typing +from collections import Counter +from typing import Any, Dict, List, Set, Union + +import torch +import torch._guards +import torch.utils._pytree as pytree +from torch._inductor.constant_folding import ConstantFolder +from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict +from torch.fx.experimental.symbolic_shapes import statically_known_true +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.multiprocessing.reductions import StorageWeakRef + +from ...utils._ordered_set import OrderedSet +from .. import config +from ..pattern_matcher import ( + CallFunction, + init_once_fakemode, + KeywordArg, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from .replace_random import replace_random_passes + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten +prims = torch.ops.prims + +pass_patterns = [ + patterns, + PatternMatcherPass(), +] + + +@init_once_fakemode +def lazy_init(): + from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init + from .pad_mm import _pad_mm_init + + _pad_mm_init() + _sfdp_init() + _misc_patterns_init() + + +def remove_no_ops( + gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node] +): + with torch.utils._python_dispatch._disable_current_modes(): + "Removes no-ops: (+ 0, - 0, * 1, / 1)" + graph = gm.graph + + def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): + if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): + return False + for field in fields: + if getattr(t1, field) != getattr(t2, field): + return False + return True + + def replace_no_op(node, replace_input_index): + replacement = node.args[replace_input_index] + + # https://github.com/pytorch/pytorch/issues/86128 causes + # non-Tensor inputs even for ops with only Tensor inputs. + # TODO - decompose/type promote to avoid this + if not all(isinstance(arg, torch.fx.Node) for arg in node.args): + return + + if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): + if fake_tensors_eq( + node.meta["val"], + replacement.meta["val"], + ("shape", "device"), + ): + with graph.inserting_after(node): + replacement = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(replacement, node.meta["val"].dtype), + ) + else: + return + + node.replace_all_uses_with(replacement) + replacement.meta.update(node.meta) + graph.erase_node(node) + + for node in graph.find_nodes(op="call_function", target=aten.add.Tensor): + # TODO handle Tensor-Scalar adds, it's a different schema + if len(node.args) == 2: + if ( + not any(e in zeros for e in node.args) + or node.kwargs.get("alpha", 1) != 1 + ): + continue + + replace_index = 1 if node.args[0] in zeros else 0 + replace_no_op(node, replace_index) + + for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor): + if len(node.args) == 2: + if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: + continue + + replace_no_op(node, 0) + + for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor): + if len(node.args) == 2: + if not any(e in ones for e in node.args): + continue + + replace_input_index = 1 if node.args[0] in ones else 0 + replace_no_op(node, replace_input_index) + + for node in graph.find_nodes(op="call_function", target=aten.div.Tensor): + if len(node.args) == 2 and node.args[1] in ones: + replace_no_op(node, 0) + + # meta tensors returned from the graph have no data and can be replaced with empty_strided + for output_node in graph.find_nodes(op="output"): + had_meta_return = False + + def visit(n): + nonlocal had_meta_return + val = n.meta.get("val") + if isinstance(val, torch.Tensor) and val.device.type == "meta": + with graph.inserting_before(output_node): + n.replace_all_uses_with( + graph.call_function( + torch.ops.aten.empty_strided.default, + args=(val.size(), val.stride()), + kwargs={"dtype": val.dtype, "device": val.device}, + ) + ) + had_meta_return = True + + torch.fx.map_arg(output_node.args, visit) + if had_meta_return: + graph.eliminate_dead_code() + + +def remove_redundant_views(gm: torch.fx.GraphModule): + """ + Removes redundant views by reusing existing ones. + """ + with torch.utils._python_dispatch._disable_current_modes(): + # A dictionary mapping a tensor to all aliased views. + views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {} + graph = gm.graph + + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten.view.dtype + ): + src = node.args[0] + to_type = node.args[1] + existing_views = views.get(src) + is_needed = True + + if existing_views: + # Replace the view with the an existing view if available. + alias = existing_views.get(to_type) + if alias: + is_needed = False + node.replace_all_uses_with(alias) + alias.meta.update(node.meta) + graph.erase_node(node) + else: + from_type = src.meta["val"].dtype + existing_views = {from_type: src} + views[src] = existing_views + + if is_needed: + # Save the new alias but do not replace existing one. + existing_views.setdefault(to_type, node) + views[node] = existing_views + + # Clean up unused views. + while True: + unused_views = [alias for alias in views if not alias.users] + if len(unused_views) == 0: + break + for unused in unused_views: + views.pop(unused) + graph.erase_node(unused) + + +class UniformValueConstantFolder(ConstantFolder): + """ + Runs constant folding and replaces tensors that have a unifrom value + with a tensor constructor call: aten.full([shape], value, ...) + """ + + def __init__(self, gm, skip_constructors=False) -> None: + super().__init__(gm, skip_constructors) + self.node_storages_ptrs: Dict[torch.fx.Node, int] = {} + self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {} + # we may constant fold a tensor which in the graph has a sym size + # see: [constant folding refining of symints] + self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {} + + # initialize symint -> node mapping so that we can + # use symint nodes in full constructors + self.symint_nodes = _SymHashingDict() + for n in self.module.graph.nodes: + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + self.symint_nodes[n.meta["val"]] = n + + # reference from torch/_funtorch/partitioners.py:get_default_op_list + self.view_op_packets = [ + aten.squeeze, + aten.unsqueeze, + aten.alias, + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + + self.indexing_op_packets = { + aten.slice, + } + + def _support_dynamic_shape(self): + return True + + def insertable_tensor_check(self, t: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor.flatten()[0].item() + self.node_replacements_shapes[node] = node.meta["val"].shape + self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) + + def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + env[n] = n.meta["val"] + else: + env[n] = self.unknown_value + + def _deduce_value(self, node: torch.fx.Node): + # deduce value for full-like nodes + # 1. for constructors, substitute value is a tensor of size [1] + # 2. for view ops/indexing, substitute value is the same as the input + # 3. for pointwise ops, run node to get the substitute value + # 4. deal with some special ops + # otherwise, stop deduce value and return unknown value + + # TODO: cat, more indexing + # TODO - do on cpu to avoid syncs + + # single-elem attrs + if node.op == "get_attr" or ( + node.op == "call_function" + and node.target == torch.ops.aten.lift_fresh_copy.default + ): + out = super(ConstantFolder, self).run_node(node) + if isinstance(out, torch.Tensor) and out.numel() == 1: + return out + + # handle device_put op + if node.target == prims.device_put.default: + return super(ConstantFolder, self).run_node(node) + + # constructors ops + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + new_args = [[1], args[1]] + return aten.full.default(*new_args, **node.kwargs) + + # handle before view ops because this changes value + if node.target == aten.view.dtype: + return super(ConstantFolder, self).run_node(node) + + # view ops, return input tensor, the first argument + if hasattr(node.target, "overloadpacket") and ( + node.target.overloadpacket in self.view_op_packets + or node.target.overloadpacket in self.indexing_op_packets + ): + assert isinstance(node.args[0], torch.fx.Node) + return self.env[node.args[0]] + + # we don't want to return unknown value for symints so that we can + # still constant fold through their use in constructors or views + # if we see them in a pointwise node (e.g., tensor * symint) + # we will bail + if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt): + return node.meta["val"] + + # pointwise ops + if isinstance(node.target, torch._ops.OpOverload) and ( + torch.Tag.pointwise in node.target.tags + or node.target is torch.ops.aten.scalar_tensor.default + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs): + return self.unknown_value + + # we run the ops with dim 1, so remove memory_format to avoid error + kwargs = dict(kwargs) + kwargs.pop("memory_format", None) + + return node.target(*args, **kwargs) + + return self.unknown_value + + +def constant_fold_uniform_value(gm: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." + aten = torch.ops.aten + + # Constant folding can leak memory, especially with repeated compilation, so we are only going to + # remove constants which can be replaced with a constructor. + cf = UniformValueConstantFolder(gm) + cf.run() + + node_replacements = cf.node_replacements + + # note: [constant folding refining of symints] + # constant folding will partially evaluate a graph such that values which have dependencies which + # are entirely known at compile time may also become compile time constants. in some cases, + # this will include symints which we had not yet previously deduced are guaranteed a + # constant value and is then deduced in constant folding. an example is: + # unbacked_symint_eq_11 = torch.full((), 11).item() + # torch.full((unbacked_symint_eq_11,), 0) + node_replacements_shapes = cf.node_replacements_shapes + + graph = gm.graph + + zeros = set() + ones = set() + + # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, + # so just constant-ify if a Tensor is unaliased + constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() + + for node in cf.node_replacements: + constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 + + for node, value in node_replacements.items(): + # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now + # hasn't shown up to be important yet + if "val" not in node.meta: + # This can only happen in AOTI + continue + + fake_tensor = node.meta["val"] + if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): + continue + + # TODO - not sure about lossy uint->python value->uint conversions + if fake_tensor.dtype in ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ): + continue + + if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: + continue + + with graph.inserting_after(node): + # the conversion from tensor and back to value can be lossy, just use the original full ctor value + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + value = node.args[1] + + # refines symints, see [constant folding refining of symints] above + for runtime_size, compile_time_size in zip( + node_replacements_shapes[node], fake_tensor.shape + ): + torch._check(runtime_size == compile_time_size) + + # replace SymInt as Node before creating a new full node + # e.g. (1, s0) -> (1, arg0_1) + node_shape = node_replacements_shapes[node] + if not all( + not isinstance(s, torch.SymInt) or s in cf.symint_nodes + for s in node_shape + ): + continue + + shapes = [ + cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s + for s in node_replacements_shapes[node] + ] + + # zeros and ones just get traced into full, so we insert those + new_node = graph.call_function( + aten.full.default, + args=(shapes, value), + kwargs={ + "dtype": fake_tensor.dtype, + "layout": torch.strided, + "device": fake_tensor.device, + "pin_memory": False, + }, + ) + + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if value == 0: + zeros.add(new_node) + elif value == 1: + ones.add(new_node) + + remove_no_ops(gm, zeros, ones) + remove_redundant_views(gm) + + +def joint_graph_passes(graph: torch.fx.GraphModule): + """ + Run FX transformations on the joint forwards+backwards graph. + """ + lazy_init() + count = 0 + if config.joint_custom_pre_pass is not None: + with GraphTransformObserver( + graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_pre_pass(graph.graph) + count += 1 + + from .post_grad import remove_noop_ops + + remove_noop_ops(graph.graph) + + if config.joint_graph_constant_folding: + with GraphTransformObserver( + graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform + ): + constant_fold_uniform_value(graph) + + if config.pattern_matcher: + for patterns in pass_patterns: + count += patterns.apply(graph.graph) # type: ignore[arg-type] + + if not config.fallback_random: + count += replace_random_passes(graph) + + if config.joint_custom_post_pass is not None: + with GraphTransformObserver( + graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_post_pass(graph.graph) + count += 1 + + if count: + stable_topological_sort(graph.graph) + graph.graph.lint() + graph.recompile() + return graph + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + pass_dict=patterns, +) +def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices: OrderedSet[torch.device] = OrderedSet() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + KeywordArg("arg"), + KeywordArg("dtype1"), + ), + KeywordArg("dtype2"), + ), + pass_dict=patterns, +) +def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): + """Remove chain of dtype conversions often created by AMP""" + graph = match.graph + node = match.output_node() + allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64} + if dtype1 in allowed and dtype2 in allowed: + repl = graph.call_function( + torch.ops.prims.convert_element_type.default, (arg, dtype2) + ) + repl.meta.update(node.meta) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + pass_dict=patterns, +) +def pointless_view(match: Match, arg, size): + """Remove no-op view""" + node = match.output_node() + arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] + if size == arg_size: + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + match.erase_nodes() + + +# When softmax is used with temperature or other scaling, we get the pattern +# +# scale(x) - scale(x).amax(dim, keepdim=True) +# +# which is expected to be at most zero, but we may end up with numerical +# discrepancies # between the recomputed values of scale(x) inside and out +# of the reduction, # depending on compiler optimizations, e.g. use of fma +# instructions. +# +# Here we replace it with the mathematically equivalent, +# +# scale(x - x.amax(dim, keepdim=True)) +# +# which is more stable as we only compute the scaling once. +# +# NOTE: This pattern must come after fused attention matching! + + +def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False): + # Allow matching inp * other and other * input + if reverse: + scaled = CallFunction( + linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE + ) + else: + scaled = CallFunction( + linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE + ) + if to_dtype: + scaled = CallFunction( + prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE + ) + amax = CallFunction( + aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim") + ) + return CallFunction(aten.sub.Tensor, scaled, amax) + + +def _other_is_broadcasted_in_dim(match): + # Check that the scaling factor is constant across the reduction dim, + # so scaling doesn't change which index corresponds to the maximum value + other = match.kwargs["other"] + if isinstance(other, (int, float)): + return True + + inp = match.kwargs["inp"] + if not all(isinstance(x, torch.fx.Node) for x in (inp, other)): + return False + + inp_example = inp.meta["val"] + other_example = other.meta["val"] + if isinstance(other_example, (torch.SymInt, torch.SymFloat)): + return True + + if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)): + return False + + inp_ndim = inp_example.ndim + other_shape = other_example.shape + if inp_ndim < len(other_shape): + return False + + # Pad other_shape to the same ndim as inp + other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) + + dim = match.kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + + return all(statically_known_true(other_shape[d] == 1) for d in dim) + + +def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) * (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for reverse, to_dtype in itertools.product((False, True), repeat=2): + register_graph_pattern( + _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(mul_softmax_pattern) + + +def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) / (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for to_dtype in (False, True): + register_graph_pattern( + _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(div_softmax_pattern) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb56590af5e3a7f43a485b9579e4631634b414f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,854 @@ +# mypy: allow-untyped-defs +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, cast, Dict, List, Optional, Set + +import torch + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternExpr, + PatternMatcherPass, +) + + +aten = torch.ops.aten +patterns = PatternMatcherPass() + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]: + ancestors = set() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return {node for node in ancestors if node.op != "placeholder"} + + +def _get_tensor(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + return val + + +@dataclass +class _AllGatherMatch: + match: Match + shard_node: torch.fx.Node + ag_node: torch.fx.Node + res_node: torch.fx.Node + gather_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_all_gather_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def make_zero_dim_all_gather_pattern(shard): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + shard, + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim == 0 + zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard")) + + def make_all_gather_split_pattern(shard): + return CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + make_zero_dim_all_gather_pattern(shard), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ) + + def make_cat_pattern(splits): + return CallFunction( + aten.cat.default, + ListOf(splits), + KeywordArg("gather_dim"), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + non_zero_dim_all_gather_pattern = make_cat_pattern( + make_all_gather_split_pattern(KeywordArg("shard")), + ) + + # Match a zero-dim all-gather in which the data is transferred as uint8 and + # viewed back as the original dtype. + zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_zero_dim_all_gather_pattern( + KeywordArg("shard"), + ), + Ignored(), + ) + + # Match a non-zero dim all-gather in which the data is transferred as uint8 + # and viewed back as the original dtype. + non_zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_cat_pattern( + CallFunction( + aten.view.dtype, + make_all_gather_split_pattern( + KeywordArg("shard"), + ), + Ignored(), + ), + ), + Ignored(), + ) + + # If two patterns with the same res_node_target have the same suffix, the + # longer pattern should appear first in the list. + # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1) + # should appear before (2) in the list. + res_node_target_to_patterns = { + aten.cat.default: [ + (non_zero_dim_all_gather_pattern, 0), + ], + aten.view.dtype: [ + (non_zero_dim_type_erased_all_gather_pattern, 0), + (zero_dim_type_erased_all_gather_pattern, 0), + ], + c10d.wait_tensor.default: [ + (zero_dim_all_gather_pattern, 0), + ], + } + + # Match in reverse to ensure longer patterns is prioritized + all_gathers = [] + visited_ag_nodes = set() + for node in reversed(graph.nodes): + for target, patterns in res_node_target_to_patterns.items(): + if node.target != target: + continue + for pattern, ag_node_idx in patterns: + match = pattern.match(node) + if not match: + continue + + assert isinstance(match, Match) + ag_node = match.nodes[ag_node_idx] + assert ag_node.target == c10d.all_gather_into_tensor.default + + if ag_node in visited_ag_nodes: + continue + visited_ag_nodes.add(ag_node) + + ag_match = _AllGatherMatch( + match=match, + shard_node=match.kwargs["shard"], + ag_node=ag_node, + res_node=node, + gather_dim=match.kwargs.get("gather_dim", 0), + group_name=match.kwargs["group_name"], + ) + all_gathers.append(ag_match) + + return list(reversed(all_gathers)) + + +@dataclass +class _ReduceScatterMatch: + match: Match + input_node: torch.fx.Node + rs_node: torch.fx.Node + res_node: torch.fx.Node + reduce_op: str + scatter_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_reduce_scatter_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def reduce_scatter_template(inp: PatternExpr): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + inp, + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input")) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + non_zero_dim_reduce_scatter_pattern = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + ) + + reduce_scatters = [] + for node in reversed(graph.nodes): + if node.target == c10d.wait_tensor.default: + if match := non_zero_dim_reduce_scatter_pattern.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + rs_node=match.nodes[-2], + res_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + rs_node=match.nodes[0], + res_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + return list(reversed(reduce_scatters)) + + +@dataclass +class _Matmul: + nodes: List[torch.fx.Node] + arg_ancestor_nodes: Set[torch.fx.Node] = field(init=False) + A_node: torch.fx.Node + B_node: torch.fx.Node + + def __post_init__(self): + assert len(self.nodes) in (1, 3) + if len(self.nodes) == 1: + assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) + else: + assert self.nodes[0].target == aten.reshape.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + assert self.nodes[2].target == aten.reshape.default + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + graph = new_node.graph + + # For 2D-matmuls, we simply replace the mm node with `new_node`. + if len(self.nodes) == 1: + mm_node = self.nodes[0] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + return + + # An ND-matmul is reshape -> mm -> reshape sequence. We first replace + # the second reshape node with `new_node`. Then, we ensure that the + # original mm node in the sequence ends up with zero users by replacing + # it with a reverse reshape of `new_node`. + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(_get_tensor(mm_node).shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + def erase(self) -> None: + for node in reversed(self.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + @classmethod + def from_match(cls, match: List[torch.fx.Node]) -> "_Matmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten.mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + return _Matmul( + nodes=match, + A_node=cast(torch.fx.Node, match[0].args[0]), + B_node=cast(torch.fx.Node, mm_node.args[1]), + ) + + +@dataclass +class _ScaledMatmul(_Matmul): + A_scale_node: torch.fx.Node + B_scale_node: torch.fx.Node + bias_node: Optional[torch.fx.Node] + result_scale_node: Optional[torch.fx.Node] + out_dtype: Optional[torch.dtype] + use_fast_accum: bool + + def __post_init__(self): + super().__post_init__() + self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node) + self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node) + + @classmethod + def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten._scaled_mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + + def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: + if idx >= len(node.args): + return default + return node.args[idx] + + return _ScaledMatmul( + nodes=match, + A_node=cast(torch.fx.Node, match[0].args[0]), + B_node=cast(torch.fx.Node, mm_node.args[1]), + A_scale_node=cast(torch.fx.Node, mm_node.args[2]), + B_scale_node=cast(torch.fx.Node, mm_node.args[3]), + bias_node=get_arg(mm_node, 4, None), + result_scale_node=get_arg(mm_node, 5, None), + out_dtype=get_arg(mm_node, 6, None), + use_fast_accum=get_arg(mm_node, 7, False), + ) + + +def _find_reshape_mm_reshape(node: torch.fx.Node) -> List[_Matmul]: + if node.target != aten.reshape.default: + return [] + + matches = [] + for mm_node in node.users: + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + continue + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + # Since the reshape -> mm -> reshape pattern would be subsumed into + # the fused op, we only match the patterns where the shape of the + # second reshape is matches the mm result produced by the fused op. + matmul_input_node = cast(torch.fx.Node, node.args[0]) + B_node = cast(torch.fx.Node, mm_node.args[1]) + matmul_out_shape = torch.Size( + [ + *_get_tensor(matmul_input_node).shape[:-1], + _get_tensor(B_node).shape[-1], + ] + ) + if _get_tensor(reshape_node).shape != matmul_out_shape: + continue + matches.append([node, mm_node, reshape_node]) + # If for some rare reason mm_node is being reshaped by two + # different reshape nodes, we only include mm_node once in the + # parsing result. + break + + matmuls = [] + for match in matches: + mm_node = match[1] + if mm_node.target == aten.mm.default: + matmul = _Matmul.from_match(match) + matmuls.append(matmul) + elif mm_node.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match(match) + matmuls.append(matmul) + else: + raise AssertionError( + "Expect the node's target to be either aten.mm.default or " + f"aten._scaled_mm.default. Got {mm_node.target}." + ) + return matmuls + + +def _find_consumer_matmuls(node: torch.fx.Node) -> List[_Matmul]: + """ + Find the matmuls that use `node` as the lhs argument. + """ + matmuls = [] + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + matmuls.extend(_find_reshape_mm_reshape(user)) + # 2D matmuls + elif user.target == aten.mm.default: + matmul = _Matmul.from_match(match=[user]) + matmuls.append(matmul) + elif user.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match([user]) + matmuls.append(matmul) + return matmuls + + +def _insert_fused_all_gather_matmul( + graph: torch.fx.Graph, + matmuls: List[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: str, +) -> torch.fx.Node: + mm_types = set(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [matmul.B_node for matmul in matmuls] + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + ) + elif mm_type == _ScaledMatmul: + scaled_matmuls = cast(List[_ScaledMatmul], matmuls) + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_scaled_matmul.default, + args=( + shard_node, + [matmul.B_node for matmul in scaled_matmuls], + scaled_matmuls[0].A_scale_node, + [matmul.B_scale_node for matmul in scaled_matmuls], + gather_dim, + group_name, + [matmul.bias_node for matmul in scaled_matmuls], + [matmul.result_scale_node for matmul in scaled_matmuls], + [matmul.out_dtype for matmul in scaled_matmuls], + [matmul.use_fast_accum for matmul in scaled_matmuls], + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None: + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... + + into + + A, Cs = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_shard_for_fused_all_gather_matmul, + ) + + shard_node, ag_node, ag_res_node, gather_dim, group_name = ( + all_gather.shard_node, + all_gather.ag_node, + all_gather.res_node, + all_gather.gather_dim, + all_gather.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + if gather_dim >= len(_get_tensor(shard_node).shape) - 1: + # Decomposing the matmul on the K dimension is not supported + return + + # Find consumer matmuls + matmuls = _find_consumer_matmuls(ag_res_node) + + # The matmuls are only fusible if non-A args don't depend on the all-gather + # result node + matmuls = [ + matmul + for matmul in matmuls + if all_gather.res_node not in matmul.arg_ancestor_nodes + ] + + if len(matmuls) == 0 or len(set(map(type, matmuls))) != 1: + return + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + + fused_node = _insert_fused_all_gather_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + matmul.erase() + all_gather.replace_with(new_ag_node) + all_gather.erase() + + # Raise ancestors of non-A args that are topologically ordered between + # ag_res_node and the matmul above fused_node. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + {x for matmul in matmuls for x in matmul.arg_ancestor_nodes}, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]: + if node.target == aten.mm.default: + return _Matmul.from_match(match=[node]) + elif node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match(match=[node]) + elif node.target == aten.reshape.default: + reshape_node_1 = node + + mm_node = reshape_node_1.args[0] + assert isinstance(mm_node, torch.fx.Node) + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + return None + + reshape_node_0 = mm_node.args[0] + assert isinstance(reshape_node_0, torch.fx.Node) + if reshape_node_0.target != aten.reshape.default: + return None + + if mm_node.target == aten.mm.default: + return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) + elif mm_node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match( + match=[reshape_node_0, mm_node, reshape_node_1] + ) + return None + + +def _insert_fused_matmul_reduce_scatter( + graph: torch.fx.Graph, + matmul: _Matmul, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.fx.Node: + if type(matmul) == _Matmul: + return graph.call_function( + torch.ops.symm_mem.fused_matmul_reduce_scatter.default, + args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name), + ) + elif type(matmul) == _ScaledMatmul: + return graph.call_function( + torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + matmul.A_scale_node, + matmul.B_scale_node, + reduce_op, + scatter_dim, + group_name, + matmul.bias_node, + matmul.result_scale_node, + matmul.out_dtype, + matmul.use_fast_accum, + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {type(matmul)}") + + +def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops.symm_mem.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_for_fused_matmul_reduce_scatter, + ) + + input_node, rs_node, rs_res_node, reduce_op, scatter_dim, group_name = ( + reduce_scatter.input_node, + reduce_scatter.rs_node, + reduce_scatter.res_node, + reduce_scatter.reduce_op, + reduce_scatter.scatter_dim, + reduce_scatter.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. + if len(input_node.users) != 1: + return + + matmul = _find_producer_matmul(input_node) + if matmul is None: + return + + if rs_res_node in matmul.arg_ancestor_nodes: + return + + graph = rs_res_node.graph + with graph.inserting_before(rs_res_node): + if "val" in matmul.A_node.meta: + restrided = restride_A_for_fused_matmul_reduce_scatter( + _get_tensor(matmul.A_node), + scatter_dim, + ) + matmul.A_node = graph.call_function( + inductor_prims.force_stride_order, + args=(matmul.A_node, restrided.stride()), + ) + + fused_node = _insert_fused_matmul_reduce_scatter( + graph, + matmul, + reduce_op, + scatter_dim, + group_name, + ) + reduce_scatter.replace_with(fused_node) + reduce_scatter.erase() + matmul.erase() + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + matmul.arg_ancestor_nodes, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _get_node_to_ancestors( + graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: + """ + Compute the ancestors for all nodes in a graph. + """ + node_to_ancestors = defaultdict(set) + for node in graph.nodes: + node_to_ancestors[node] = set(node.all_input_nodes) + for dep in node.all_input_nodes: + node_to_ancestors[node] |= node_to_ancestors[dep] + + return node_to_ancestors + + +def _get_collective_to_overlappable_nodes( + graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, List[torch.fx.Node]]: + """ + For each collective in the graph, find nodes that are neither ancestors nor + descendants of the collective. + """ + + def is_collective(node) -> bool: + # Only consider all-gather and reduce-scatter in the context of + # micro-pipeline TP. + return node.target in [ + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + + node_to_ancestors = _get_node_to_ancestors(graph) + collective_to_overlappable_nodes = defaultdict(list) + for node in graph.nodes: + if not is_collective(node): + continue + for x in graph.nodes: + if ( + node not in node_to_ancestors[x] + and x not in node_to_ancestors[node] + and x.op == "call_function" + ): + collective_to_overlappable_nodes[node].append(x) + + return collective_to_overlappable_nodes + + +def _get_unexposed_collectives(graph: torch.fx.Graph) -> List[torch.fx.Node]: + """ + Find all unexposed collectives in the graph. + + Because we don't have the runtime estimate, this function is a rough + estimation using the following strong/hand-wavy assumptions: + + - Only a predefined set of "compute intensive" operation can hide a collective. + - Any "compute intensive" operation can hide exactly one collective. + """ + + def _is_compute_intensive(node: torch.fx.Node) -> bool: + return node.target in [torch.ops.aten.mm.default] + + collective_to_overlapping_candidates = defaultdict(list) + available_nodes = set() + collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) + for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): + candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] + collective_to_overlapping_candidates[collective] = candidates + available_nodes |= set(candidates) + + unexposed_collectives = [] + for ( + collective, + overlapping_candidates, + ) in collective_to_overlapping_candidates.items(): + # Each collective consumes exactly one overlapping candidate + for x in overlapping_candidates: + if x in available_nodes: + unexposed_collectives.append(collective) + available_nodes.remove(x) + break + return unexposed_collectives + + +def micro_pipeline_tp_pass(graph: torch.fx.Graph): + all_gathers = find_all_gather_patterns(graph) + reduce_scatters = find_reduce_scatter_patterns(graph) + + # When a collective can be hidden through either simple overlapping or + # micro-pipeline TP, we prefer simple overlapping to avoid the overhead + # associated with decomposition. If reorder_for_compute_comm_overlap is + # enabled, we identify collectives that can be hidden through simple + # overlapping and exclude them from micro-pipeline TP candidates. + if config.reorder_for_compute_comm_overlap: + unexposed_collectives = _get_unexposed_collectives(graph) + all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] + reduce_scatters = [ + x for x in reduce_scatters if x.rs_node not in unexposed_collectives + ] + + for all_gather in all_gathers: + fuse_all_gather_matmul(all_gather) + + for reduce_scatter in reduce_scatters: + fuse_matmul_reduce_scatter(reduce_scatter) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..0f608952a2fb2722934f822edf88281152b1bdcc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +import functools +from typing import Dict, Set, Tuple + +import torch +from torch._dynamo.utils import counters +from torch._ops import OpOverload, OpOverloadPacket + +from ..pattern_matcher import fwd_only, register_replacement + + +aten = torch.ops.aten + + +@functools.lru_cache(None) +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: Dict[str, Tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: Dict[str, str] + cache: Dict["torch.fx.graph.Target", Set[str]] + + def __init__(self) -> None: + self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. + continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = set() + for sig in signatures: + for param_name in sig.parameters.keys(): + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..156760a68e7e6978af1e9d67c89a37dfe963129a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1266 @@ +# mypy: allow-untyped-defs +import functools +import operator +from functools import reduce +from typing import Any, Tuple + +import torch +from torch.fx.experimental.symbolic_shapes import has_free_symbols + +from .. import ir +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..virtualized import ops, V +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, + _register_woq_lowerings, +) + + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def _hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op, lowp_dtype=None): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + [ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, fn): + binary_nodes = filter_nodes(match.nodes, fn) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. + if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + if any( + get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size() + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. + def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = set() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert ( + len(_binary_node.all_input_nodes) == 2 + ), "Binary node should have 2 input nodes." + _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + if isinstance(_other.data, ir.View): + return _can_be_inplace(_other.data) + else: + return not ( + isinstance(_other.data, ir.ReinterpretView) + or len(_other.get_inputs_that_alias_output()) > 0 + ) + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + # Make sure the other is not an alias or mutation(fx side doesn't has such info). + other.realize() + if not _can_be_inplace(other): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__( + self, op_name: str, scalars_attr=None, algorithm_attr=None + ) -> None: + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. + @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=1, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + def get_val(val): + return val if isinstance(val, int) else val.meta.get("val") + + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size([get_val(val) for val in reshape_2[:-1]]) + can_remove_reshape = can_remove_reshape and ( + reduce( + operator.mul, + [get_val(val) for val in reshape_2[:-1]], + ) + == get_val(reshape_1[0]) + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + packed_weight_node = linear_node.args[1] + assert packed_weight_node.target == mkldnn._reorder_linear_weight + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.target == aten.permute.default + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only folding bias if it is a constant + return False + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + assert weight_meta.dtype in ( + torch.bfloat16, + torch.float16, + ) + if bias_meta.dtype != weight_meta.dtype: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(1) + ) + + # convert linear+bias to a single linear for applying fusion path. + @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=1, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes() + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not mkldnn._is_mkldnn_bf16_supported() + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not mkldnn._is_mkldnn_fp16_supported() + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or (meta_value.dim() != 4 and meta_value.dim() != 5) + ): + return False + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. + if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. + """ + linear_node = match.output_node() + # mkldnn linear only supports beta=1or0 and alpha=1 + if linear_node.target == aten.addmm.default: + alpha = linear_node.kwargs.get("alpha", 1.0) + beta = linear_node.kwargs.get("beta", 1.0) + if (beta != 0.0 and beta != 1.0) or alpha != 1.0: + return False + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + if linear_node.args[weight_idx].op != "get_attr": + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + batch_size = input_meta_value.shape[0] + if ( + input_meta_value.dtype == torch.float64 + or weight_meta_value.dtype == torch.float64 + ): + return False + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + conv_node = match.output_node() + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_weight_op = mkldnn._reorder_convolution_weight + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + if not has_free_symbols(input_size): + packed_weight_inputs = ( + (args[1],) + tuple(constant_args) + (input_size,) + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. + packed_weight_node = args[1] + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + input = args[0] + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + + @register_freezing_graph_pattern( + CallFunction( + aten.addmm.default, + Arg(), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), + extra_check=_is_packable_linear, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target == aten.mm.default else args[1] + bias = ( + None + if linear_node.target == aten.mm.default + or ( + linear_node.target == aten.addmm.default + and linear_node.kwargs.get("beta", 1.0) == 0.0 + ) + else args[0] + ) + weight = args[1] if linear_node.target == aten.mm.default else args[2] + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert ( + is_lp_weight or mkldnn._is_mkldnn_acl_supported() + ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode + packed_weight_op = ( + mkldnn._reorder_linear_weight + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation is True + ) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node) + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation is True + ): + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + packed_linear_node = graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + torch._C._nn.mkldnn_reorder_conv3d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.lru_cache(None) + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. + # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and not torch.ops.mkldnn._is_mkldnn_acl_supported() + ): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + _register_woq_lowerings() + + @functools.lru_cache(None) + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7069b100a38bd77d010d15268c5018082b56173b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,212 @@ +# mypy: allow-untyped-defs +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim + +from .. import config + + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. +def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(set(dict_base.keys())) != len(set(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base.keys(): + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception: + logger.exception("Exception when comparing gradients") + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. Numerical result : %s", + res, + ) + except Exception as e: + logger.exception( + "Exception when optimizer is added to check parameter names" + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..87450b34e7ee2734f22f29fde2e1b995aa1128e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,881 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import operator +import typing +from typing import Callable, List, Optional, Union + +import torch +import torch._inductor.runtime.runtime_utils +from torch import Tensor +from torch._dynamo.utils import counters +from torch._inductor import utils +from torch._inductor.autoheuristic.autoheuristic import ( + AHContext, + AutoHeuristic, + LocalFeedback, +) +from torch._inductor.autoheuristic.autoheuristic_utils import ( + context_add_strides, + context_add_using_tf32, + pad_mm_operations, + pad_mm_precondition, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch + +from ...utils._triton import has_triton +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) + + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. +_skip_do_bench_times = False + + +def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args(*arg_names): + def decorator(func): + def wrapper(match): + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + return get_alignment_size_dtype(x.dtype) + + +def get_alignment_size_dtype(dtype: torch.dtype) -> int: + if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16: + return 8 + elif dtype == torch.float32 or dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def should_pad_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. + def valid_shape_and_stride(t: Optional[Tensor]) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimentions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + + # ignore dim that can be squeezed away + if x == 1: + return 0 + + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def pad_addmm( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta=1.0, + alpha=1.0, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +): + # for paddings, dim order is reversed for some reasons + # and for every dim, we need to specify left and right padding + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def addmm_replace( + input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0 +) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # we have experienced some large perf hits in this case, even in bandwidth bound regimes + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and torch.cuda.get_device_capability() < (9, 0) + ): # doesnt repro on h100s: + return True + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we extimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.lru_cache(None) +def get_pad_cache(): + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key: str) -> bool: + return get_pad_cache().lookup(key) + + +def set_cached_should_pad(key: str, value: bool): + return get_pad_cache().set_value(key, value=value) + + +def get_cached_base_mm_benchmark_time(key: str) -> float: + return get_pad_cache().lookup(key) + + +def set_cached_base_mm_benchmark_time(key: str, value: float): + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + match, + mat1: Tensor, + mat2: Tensor, + op, + input: Optional[Tensor] = None, + is_base_time_key=False, +) -> str: + def tensor_key(t): + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32 + ) + + def fmt_pad(name): + if is_base_time_key: + return None + return f"exclude_pad:{should_exclude_padding_time(match, name)}" + + key = ( + tensor_key(mat1), + tensor_key(mat2), + fmt_pad("mat1"), + fmt_pad("mat2"), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + key = str(key) + if is_base_time_key: + key = f"base mm time: {key}" + return key + + +def get_non_view_def(node): + if node.op == operator.getitem: + return get_non_view_def(node.args[0]) + + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and utils.is_view(node.target) + ): + return get_non_view_def(node.all_input_nodes[0]) + + return node + + +def should_exclude_padding_time(match, arg_name): + node_def = get_non_view_def(match.kwargs[arg_name]) + + # constant padding converts tensors to contiguous so even if the input tensor + # can be planned layout transform is not free. TODO - way to pad and preserve layout ? + if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous(): + return False + + # TODO - see issue https://githpub.com/pytorch/pytorch/issues/128889 + # We would only able to completely plan these out if we were only doing + # first dimension padding. non-first we would still need a copy + # because these outputs are fixed dense. + cannot_plan_output = [ + aten.mm.default, + aten.convolution.default, + aten.convolution_backward.default, + aten.bmm.default, + aten.addmm.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_efficient_attention.default, + ] + + if node_def.target in cannot_plan_output: + return False + + if ( + node_def.target == aten.cat.default + and len(node_def.all_input_nodes) + > torch._inductor.config.max_pointwise_cat_inputs + ): + return False + + # optimistically assume we should be able to memory plan away + # all non inputs + return node_def.op != "placeholder" + + +def should_pad(key: str, ori_time, pad_time) -> bool: + multiplier = 1.1 + # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options: + multiplier = torch._inductor.config.post_grad_fusion_options[ + "shape_padding_multiplier" + ].get("value", 1.1) + counters["inductor"]["shape_padding_multiplier"] += 1 + should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier + set_cached_should_pad(key, should_pad) + return should_pad + + +def should_pad_bench( + match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None +) -> bool: + do_bench = functools.partial( + torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, + warmup=5, + ) + m_padded_length = 0 + n_padded_length = 0 + batchsize = 1 + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + elif op is torch.ops.aten.bmm: + batchsize = mat1.shape[0] + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + def realize_symbols(ds): + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + + if torch._inductor.config.force_shape_pad: + return True + + if not has_triton(): + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(match, mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = realize_symbols(t.size()) + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + + # since we key on whether or not the inputs can be memory planned, set cache for the + # original time which is unaffected by whether or not the input can be planned + ori_time_key = should_pad_bench_key( + match, mat1, mat2, op, input, is_base_time_key=True + ) + ori_time = get_cached_base_mm_benchmark_time(ori_time_key) + if ori_time is None and op is torch.ops.aten.addmm and input is not None: + # realize bias for addmm + input = realize_tensor(input) + + mat1_pad = mat1 + mat2_pad = mat2 + + is_bmm = op is torch.ops.aten.bmm + + mat1_pre_padded = should_exclude_padding_time(match, "mat1") + fns = [] + if mat1_pre_padded and (m_padded_length or k_padded_length): + mat1_pad = pad_mat1( + mat1_pad, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0) + else: + mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0) + + fns.append(write_pad) + + mat2_pre_padded = should_exclude_padding_time(match, "mat2") + if mat2_pre_padded and (k_padded_length or n_padded_length): + mat2_pad = pad_mat2( + mat2_pad, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0) + else: + mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0) + + fns.append(write_pad) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda: + input_pad = torch.randn_like(input) + fns.append( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + elif op is torch.ops.aten.mm: + fns.append( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + else: + fns.append( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + + def orig_bench_fn(): + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + op(mat1, mat2) + else: + op(input, mat1, mat2) + + def pad_bench_fn(): + for fn in fns: + fn() + + if ( + torch._inductor.config.run_autoheuristic("pad_mm") + and op is torch.ops.aten.mm + ): + ah_should_pad = run_autoheuristic( + mat1, + mat2, + orig_bench_fn, + pad_bench_fn, + m_padded_length, + k_padded_length, + n_padded_length, + do_bench, + mat1_pre_padded, + mat2_pre_padded, + ori_time, + ori_time_key, + key, + ) + if ah_should_pad is not None: + return ah_should_pad + + if ori_time is None: + ori_time = do_bench(orig_bench_fn) + set_cached_base_mm_benchmark_time(ori_time_key, ori_time) + + pad_time = do_bench(pad_bench_fn) + return should_pad(key, ori_time, pad_time) + + +def get_context( + mat1: Tensor, + mat2: Tensor, + mat1_pre_padded: bool, + mat2_pre_padded: bool, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +): + context = AHContext() + + context.add_feature("m", mat1.shape[0]) + context.add_feature("k", mat1.shape[1]) + context.add_feature("n", mat2.shape[1]) + + context_add_strides(context, "mat1", mat1.stride()) + context_add_strides(context, "mat2", mat2.stride()) + + context.add_feature("m_padded_length", m_padded_length) + context.add_feature("k_padded_length", k_padded_length) + context.add_feature("n_padded_length", n_padded_length) + + context.add_feature("mat1_align_size", get_alignment_size(mat1)) + context.add_feature("mat2_align_size", get_alignment_size(mat2)) + + context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True) + + context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True) + context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True) + + context_add_using_tf32(context, mat1.dtype) + return context + + +def run_autoheuristic( + mat1: Tensor, + mat2: Tensor, + orig_bench_fn: Callable[[], None], + pad_bench_fn: Callable[[], None], + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + do_bench, + mat1_pre_padded: bool, + mat2_pre_padded: bool, + ori_time, + ori_time_key: str, + key: str, +) -> Optional[bool]: + def feedback_fn(choice: str): + if choice == orig_choice: + return do_bench(orig_bench_fn) + elif choice == pad_choice: + return do_bench(pad_bench_fn) + return None + + def fallback() -> str: + return "autotune" + + orig_choice = "orig" + pad_choice = "pad" + choices = [orig_choice, pad_choice] + feedback = LocalFeedback(feedback_fn) + context = get_context( + mat1, + mat2, + mat1_pre_padded, + mat2_pre_padded, + m_padded_length, + k_padded_length, + n_padded_length, + ) + name = "pad_mm" + autoheuristic = AutoHeuristic( + fallback=fallback, + choices=choices, + feedback=feedback, + context=context, + name=name, + augment_context=pad_mm_operations(), + precondition=pad_mm_precondition, + ) + choice = autoheuristic.get_choice() + choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None} + ah_should_pad = choice2should_pad.get(choice, None) + + if torch._inductor.config.collect_autoheuristic(name): + ah_ori_time = autoheuristic.get_collected_feedback(orig_choice) + ah_pad_time = autoheuristic.get_collected_feedback(pad_choice) + + # if precondition is not satisifed, autoheuristic does not collect data + if ah_ori_time is not None and ah_pad_time is not None: + if ori_time is None: + set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time) + return should_pad(key, ah_ori_time, ah_pad_time) + if ah_should_pad is not None: + set_cached_should_pad(key, ah_should_pad) + return ah_should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.mm + ) + + +def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False): + if m_padded_length == 0 and k_padded_length == 0: + return mat1 + elif k_padded_length != 0 and m_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, k_padded_length, 0, m_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat1, pad_arg) + elif m_padded_length != 0: + return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1) + else: + assert k_padded_length != 0 + return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2) + + +def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False): + if k_padded_length == 0 and n_padded_length == 0: + return mat2 + elif k_padded_length != 0 and n_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, n_padded_length, 0, k_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat2, pad_arg) + elif k_padded_length != 0: + return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1) + else: + assert n_padded_length != 0 + return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2) + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + res = aten.mm(mat1, mat2) + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + return pad_mm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.bmm + ) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=True, + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=True, + ) + res = aten.bmm(mat1, mat2) + if m_padded_length != 0: + res = res[:, :-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :, :-n_padded_length] + return res + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + return pad_bmm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +@functools.lru_cache(None) +def _pad_mm_init(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + name = pattern.__name__ + + gen_register_replacement( + f"{name}_training", + pattern, + replacement, + args, + joint_fwd_bwd, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + + gen_register_replacement( + f"{name}_inference", + pattern, + replacement, + args, + fwd_only, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..4e73931fcef23b606a56c977a36747ca1eefc3ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1318 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters, optimus_scuba_log +from torch._inductor import comms +from torch._inductor.virtualized import ops +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype +from torch._utils_internal import upload_graph +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from .. import config, ir, pattern_matcher +from ..codegen.common import BackendFeature, has_backend_feature +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from ..utils import decode_device, get_gpu_type, is_pointwise_use +from ..virtualized import V +from .b2b_gemm import B2B_GEMM_PASS +from .ddp_fusion import fuse_ddp_communication +from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import micro_pipeline_tp_pass +from .pre_grad import is_same_dict, save_inductor_dict +from .reinplace import reinplace_inplaceable_ops +from .split_cat import POST_GRAD_PATTERNS + + +if TYPE_CHECKING: + from sympy import Expr + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. + """ + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + reorder_for_locality(gm.graph) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if config.post_grad_custom_pre_pass is not None: + with GraphTransformObserver( + gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_pre_pass(gm.graph) + + if config.pattern_matcher: + lazy_init() + optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph) + group_batch_fusion_passes(gm.graph, pre_grad=False) + remove_noop_ops(gm.graph) + for patterns in pass_patterns: + patterns.apply(gm.graph) # type: ignore[arg-type] + for pass_name in config.post_grad_fusion_options: + # skip all patterns for group batch fusions + if pass_name in POST_GRAD_FUSIONS: + continue + pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if not is_same_dict(counters["inductor"], inductor_before_change): + optimus_scuba_log[ + f"{pattern_matcher_pass.pass_name}_post_grad" + ] = upload_graph(gm.graph) + if config.b2b_gemm_pass: + B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] + + if config._micro_pipeline_tp: + micro_pipeline_tp_pass(gm.graph) + + if config._fuse_ddp_communication: + fuse_ddp_communication( + gm.graph, + config._fuse_ddp_communication_passes, + config._fuse_ddp_bucket_size, + ) + + if config.post_grad_custom_post_pass is not None: + with GraphTransformObserver( + gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_post_pass(gm.graph) + + stable_topological_sort(gm.graph) + + move_constructors_to_gpu(gm.graph) + + fake_tensor_updater.incremental_update() + + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + reinplace_inplaceable_ops(gm.graph) + decompose_auto_functionalized(gm.graph) + + comms.reinplace_fsdp_all_gather(gm.graph) + + gm.recompile() + optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph) + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + +def reorder_for_locality(graph: torch.fx.Graph): + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = set() + + # only reorder nodes before the first copy_ in the graph. + # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesnt work well with mutation + first_copy = next( + iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)), + None, + ) + past_mutating_epilogue = True if first_copy is None else False + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1): + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, extra_check, pass_dict=pass_patterns[pass_number] + ) + + +################################################################################ +# Actual patterns below this point. +# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape + *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape + if k1 != k2: + return False + + *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape + *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedius. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_lowering_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. + """ + from torch._inductor import metrics + + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + selector_loader = selector.make_loader() + + def inner_fn(idx): + selector_idx = list(idx) + selector_idx[dim] = 0 + + selector = selector_loader(selector_idx) + return ops.where( + selector == ops.index_expr(idx[dim], torch.int64), + ops.constant(val, dtype), + ops.constant(background_val, dtype), + ) + + return ir.Pointwise.create( + device=selector.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=shape, + ) + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +def cuda_and_enabled_mixed_mm(match): + return ( + (config.use_mixed_mm or config.mixed_mm_choice != "default") + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and ( + match.kwargs["mat2_dtype"].itemsize + > match.kwargs["mat2"].meta.get("val").dtype.itemsize + ) + and has_backend_feature("cuda", BackendFeature.TRITON_TEMPLATES) + ) + + +def cuda_and_enabled_mixed_mm_and_not_int8(match): + return ( + cuda_and_enabled_mixed_mm(match) + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8) + != torch.int8 + ) # bitshift numerics in triton and pytorch don't match for torch.int8 + + +""" + this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor + (where the int4 and uint4x2 are represented with int8 and uint8 respectively) + where every other row of the int4 is packed with the row above it as: + uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4 + + unpack formulas: + int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8 + int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8 + + thus matching on unpack formula: + torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8)) + + note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior + of the kernel matches the pytorch formula for all dtypes except torch.int8 + where the bitwise numerics in triton do not match those in pytorch. +""" + + +@register_lowering_pattern( + CallFunction( + aten.mm.default, + KeywordArg("mat1"), + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + aten.bitwise_and.Scalar, + KeywordArg("mat2"), + 0xF, + ), + # CallFunction( + # aten.__rshift__.Scalar, + # KeywordArg("mat2"), + # 4, + # ), + True, + ), + 1, + ), + KeywordArg("mat2_mm_shape"), + ), + KeywordArg("mat2_dtype"), + ), + 8, + ), + ), + extra_check=cuda_and_enabled_mixed_mm_and_not_int8, +) +def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype): + return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm( + mat1, mat2, mat2_mm_shape, mat2_dtype + ) + + +""" + torch.mm(mat1, mat2.to(mat2_dtype)) +""" + + +@register_lowering_pattern( + CallFunction( + aten.mm, + KeywordArg("mat1"), + CallFunction( + prims.convert_element_type.default, + KeywordArg("mat2"), + KeywordArg("mat2_dtype"), + ), + ), + extra_check=cuda_and_enabled_mixed_mm, +) +def mixed_mm(match: Match, mat1, mat2, mat2_dtype): + return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + match.replace_by_example(repl, list(shape)) + + +def shape_of_mm(a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + +@register_lowering_pattern( + CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()), +) +def cat_mm(match, inputs, dim): + return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm) + + +@register_lowering_pattern( + CallFunction( + aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg() + ), +) +def cat_addmm(match, inputs, dim): + def shape_of(bias, a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of) + + +def cat_tuned_op(match, inputs, dim, *, op, shape_of): + """ + Memory planning to remove cat. We can't use the stock memory + planner since autotuning matmuls needs to know the output layout. + """ + if len(inputs) == 1: + return op(*inputs[0]) + + # TODO(jansel): rewrite this as a bmm? + if dim < 0: + dim += len(shape_of(*inputs[0])) + assert dim in (0, 1) + notdim = 1 - dim + + new_size: Optional[Union[List[Expr], List[int]]] = None + offsets_start = [] + offsets_end = [] + + # compute output sizes + for i in range(len(inputs)): + shape = shape_of(*inputs[i]) + if new_size is None: + new_size = shape + else: + new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload] + shape[notdim], new_size[notdim] + ) + new_size[dim] += shape[dim] + offsets_start.append(new_size[dim] - shape[dim]) + offsets_end.append(new_size[dim]) + + assert new_size is not None + dtype = functools.reduce( + torch.promote_types, + [x.get_dtype() for x in itertools.chain.from_iterable(inputs)], + ) + device = inputs[0][0].get_device() + kernel = ir.ConcatKernel( + name=None, + layout=ir.FixedLayout(device, dtype, new_size), + inputs=[], + ) + kernel_tensor = ir.TensorBox.create(kernel) + + for i in range(len(inputs)): + dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i]) + src = op(*inputs[i], layout=dst.get_layout()).data.data + assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer)) + src.layout = ir.NonOwningLayout(dst) + kernel.inputs.append(src) + + kernel.name = V.graph.register_buffer(kernel) + kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs) + V.graph.register_operation(kernel) + return kernel_tensor + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. + if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = { + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + } + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != set(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: Dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) # type: ignore[arg-type] + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + if ( + statically_known_true(sym_eq(start, 0)) + and statically_known_true(end >= 2**63 - 1) + and statically_known_true(sym_eq(step, 1)) + ): + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + if start == 0 and end >= 2**63 - 1 and step == 1: + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + +@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) +def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view) +def view_noop(arg, size): + return arg.shape == size + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = set() + input_storages = set() + output_storages = set() + + for node in graph.find_nodes(op="placeholder"): + inputs.add(node) + input_storages.add(get_node_storage(node)) + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + # nested subgraphs can have singleton outputs + outputs = (outputs,) + for out in outputs: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized and triton_kernel_wrapper_functional + nodes into clones and the underlying mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + for node in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized + ): + raise AssertionError("auto_functionalized was not removed") + + for node in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + + for node in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. + """ + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.view.default + ): + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not inp.meta["val"].is_cuda: + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): + def repl(inp, x1, x2): + return x1 @ x2 + inp + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +@register_lowering_pattern( + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +@register_lowering_pattern( + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None): + return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype) + + +def is_index_put_and_requires_h2d_sync_for_gpu_value(node): + from torch.fx.operator_schemas import normalize_function + + if node.target not in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: + return False + # Inductor falls back to aten.index_put_. + # index_put_ will will call nonzero() and perform a H2D sync if + # any of its indices are bool/byte tensors + # However, it will short-circuit this H2D sync and run mask_fill_ + # if the value we are putting is a cpu scalar. + # Therefore, when inductor sees an index_put_ with byte tensor indices, + # it should *not* convert the cpu scalar value into a gpu tensor. + args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc] + any_byte_bool_indices = False + indices = args_[1] + for i in indices: + if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]: + any_byte_bool_indices = True + + val = args_[2].meta["val"] + val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1 + # If both these conditions hold, then converting the val + # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_ + return any_byte_bool_indices and val_is_cpu_scalar + + +class ConstructorMoverPass: + def __init__(self, target: str, allow_outputs: bool = False) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + """ + + self.target = target + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. + """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + if is_index_put_and_requires_h2d_sync_for_gpu_value(node): + return True + + return False + + def get_node_device(self, node: fx.Node) -> Optional[torch.device]: + """ + Get the device of a node. + """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: Dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = set() + constructors = [] + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if not node.kwargs.get("device") == torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = next(iter(target_devices)) + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: List[fx.Node] + ) -> Set[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to gpu + cannot_move_to_gpu: Set[fx.Node] = set() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to gpu, the other cannot as well. + # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = { + c: {c} for c in constructors + } + + def make_dependencies_equivalent( + set1: Set[fx.Node], set2: Set[fx.Node] + ) -> Set[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: List[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_gpu.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a gpu + # tensor. we can convert its cpu input to gpu without making further changes + node_device = self.get_node_device(user) + if ( + self.allow_cpu_device(user) + and node_device + and node_device.type == self.target + ): + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_gpu.update(constructor_dependencies[node]) + + all_cannot_move_to_gpu = cannot_move_to_gpu.copy() + for constructor in cannot_move_to_gpu: + all_cannot_move_to_gpu.update(equal_constructor_sets[constructor]) + + return set(constructors) - all_cannot_move_to_gpu + + +def move_constructors_to_gpu(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to gpu when safe + """ + ConstructorMoverPass(get_gpu_type())(graph) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..eef058ab0b5aa22dfd74c6bda72aefba5b61b443 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pre_grad.py @@ -0,0 +1,800 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +from typing import Dict, Optional + +import torch +import torch.nn as nn +from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log +from torch._utils_internal import upload_graph +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights + +from .. import config +from ..fx_utils import matches_module_function_pattern +from ..pattern_matcher import ( + init_once_fakemode, + PatternMatcherPass, + stable_topological_sort, +) +from ..utils import is_cpu_device, pass_execution_and_save +from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS +from .misc_patterns import numpy_compat_normalization +from .split_cat import PRE_GRAD_PATTERNS + + +log = logging.getLogger(__name__) + +efficient_conv_bn_eval_pass = PatternMatcherPass( + pass_name="efficient_conv_bn_eval_pass" +) + +fuse_split_linear_add_pass = PatternMatcherPass( + pass_name="fuse_split_linear_add_pass", +) +fuse_chunk_squeeze_cat_pass = PatternMatcherPass( + pass_name="fuse_chunk_squeeze_cat_pass", +) +remove_reshape_pass = PatternMatcherPass( + pass_name="remove_reshape_pass", +) + +# based on predispatch aten IR +normalization_pass_aten = PatternMatcherPass() +merge_splits_pass_aten = PatternMatcherPass() +split_cat_pass_aten = PatternMatcherPass() +unbind_stack_pass_aten = PatternMatcherPass() +merge_getitem_cat_pass_aten = PatternMatcherPass() +merge_stack_tahn_unbind_pass_aten = PatternMatcherPass() +mutate_cat_pass_aten = PatternMatcherPass() +remove_split_with_size_one_pass_aten = PatternMatcherPass() + + +def save_inductor_dict(pass_to_compare=None): + if not pass_to_compare: + pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list( + config.post_grad_fusion_options.keys() + ) + return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare} + + +def is_same_dict(inductor_dict, optimus_dict): + for pass_name, count in optimus_dict.items(): + if count != dict(inductor_dict).get(pass_name, 0): + return False + return True + + +def normalize_node_kwargs_pass(graph): + return None + + +def fuse_parallel_linear_pass(graph): + return None + + +def remove_split_ops(graph, shape_prop): + return None + + +def fuse_chunk_reshape_unsqueeze_concat_pass(graph): + return None + + +def fuse_chunk_reshape_concat_pass(graph): + return None + + +def remove_noop_pass(graph): + return None + + +def stack_to_unsqueeze_pass(graph): + return None + + +@init_once_fakemode +def lazy_init(): + from . import efficient_conv_bn_eval, split_cat # noqa: F401 # noqa: F401 + + if config.is_fbcode(): + from . import fb # type: ignore[attr-defined] # noqa: F401 + + +def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None): + """ + Apply passes on the input FX graph using Torch IR. + + WARNING: + The IR before grad is not functional or normalized, so it is harder + to write passes on this IR. Passes must be safe with respect to + aliasing and mutation and need to handle all possible arg schemas. + + Consider adding a new pass to post_grad.py or joint_graph.py which + are after functionalization and normalization. + """ + if config.pattern_matcher: + lazy_init() + if hasattr( + config, "fx_passes_numeric_check" + ) and config.fx_passes_numeric_check.get("pre_grad", False): + gm_before_fx_passes = gm.__copy__() + # explicitly run with predispatch atenIR based passes + if config.is_predispatch: + + def shape_prop(mod) -> None: + ShapeProp( + gm=mod, + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode=detect_fake_mode(example_inputs), + ).propagate(*example_inputs) + + # normalization pass + pass_execution_and_save( + normalization_pass_aten.apply, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply normalization pass", + ) + # normalize kwargs, must be called as the first pass + pass_execution_and_save( + normalize_node_kwargs_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass", + ) + pass_execution_and_save( + remove_noop_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply remove_noop pass", + ) + pass_execution_and_save( + fuse_chunk_reshape_concat_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_concat_pass", + ) + pass_execution_and_save( + group_batch_fusion_passes, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply group_batch_fusion", + ) + pass_execution_and_save( + normalize_node_kwargs_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply normalize_node_kwargs_pass", + ) + pass_execution_and_save( + fuse_chunk_squeeze_cat_pass.apply, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass", + ) + pass_execution_and_save( + fuse_split_linear_add_pass.apply, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass", + ) + pass_execution_and_save( + remove_reshape_pass.apply, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply remove_reshape_pass", + ) + pass_execution_and_save( + fuse_parallel_linear_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass", + ) + pass_execution_and_save( + lambda graph: remove_split_ops(graph.owning_module, shape_prop), + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply remove_split_ops", + ) + # run before fuse_chunk_reshape_unsqueeze_concat_pass + pass_execution_and_save( + stack_to_unsqueeze_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply stack_to_unsqueeze_pass", + ) + pass_execution_and_save( + fuse_chunk_reshape_unsqueeze_concat_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_unsqueeze_concat_pass", + ) + # Remove noops at the end, which may be generated other passes. + pass_execution_and_save( + remove_noop_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply remove_noop pass", + ) + shape_prop(gm) + + else: + # We only log the graph with changes to avoid the excessive compilation time + # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/ + if example_inputs is not None: + gm = fuse_fx(gm, example_inputs) + numpy_compat_normalization(gm.graph) + optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph) + group_batch_fusion_passes(gm.graph, pre_grad=True) + for pass_name in config.pre_grad_fusion_options: + # skip all patterns for group batch fusions + if pass_name in PRE_GRAD_FUSIONS: + continue + pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + # we support run same pattern multiple times, the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if not is_same_dict(counters["inductor"], inductor_before_change): + optimus_scuba_log[ + f"{pattern_matcher_pass.pass_name}_pre_grad" + ] = upload_graph(gm.graph) + # TODO: move efficient_conv_bn_eval_pass to the fusions dict too. + efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.pre_grad_custom_pass is not None: + with GraphTransformObserver( + gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform + ): + config.pre_grad_custom_pass(gm.graph) + stable_topological_sort(gm.graph) + + from .quantization import quant_lift_up + + quant_lift_up(gm) + + gm.graph.lint() + gm.recompile() + optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph) + + if ( + config.pattern_matcher + and hasattr(config, "fx_passes_numeric_check") + and config.fx_passes_numeric_check.get("pre_grad", False) + and example_inputs is not None + ): + from .numeric_utils import numeric_check_if_enabled + + gm_after_fx_passes = gm.__copy__() + numeric_check_if_enabled( + gm_before_fx_passes, # type: ignore[possibly-undefined] + gm_after_fx_passes, + example_inputs, + config.fx_passes_numeric_check.get("num_iterations", 1), + config.fx_passes_numeric_check.get("precision", 1e-4), + ) + + return gm + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: + is_cpu = is_cpu_device(example_inputs) + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode = detect_fake_mode(example_inputs) + + gm = sink_cat_after_pointwise(gm) + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + with GraphTransformObserver( + gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform + ): + gm = linear_permute_fusion(gm) + with GraphTransformObserver( + gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_linear_fusion(gm) + with GraphTransformObserver( + gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled() or not is_cpu: + return gm + if config.freezing: + with GraphTransformObserver( + gm, "remove_identity", config.trace.log_url_for_graph_xform + ): + gm = remove_identity(gm) + with GraphTransformObserver( + gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform + ): + gm = fuse_conv_bn(gm) + return gm + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: + """ + Fuses Convolution/BN layers for inference purposes. + """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + + class ConvBNFusion: + def __init__( + self, + bn_node, + conv_module, + bn_module=None, # For BN Module + bn_running_mean=None, # For Functional BN + bn_running_var=None, + bn_eps=None, + bn_weight=None, + bn_bias=None, + ) -> None: + self.bn_nodes = [ + bn_node, + ] + self.conv_module = conv_module + self.bn_module = bn_module + self.bn_running_mean = bn_running_mean + self.bn_running_var = bn_running_var + self.bn_eps = bn_eps + self.bn_weight = bn_weight + self.bn_bias = bn_bias + self.fusion_enabled = True + + def add_bn_node(self, bn_node): + self.bn_nodes.append(bn_node) + + def disable_fusion(self): + self.fusion_enabled = False + + def is_fusion_enabled(self): + return self.fusion_enabled + + conv_bn_to_fuse: Dict[int, ConvBNFusion] = {} + for pattern in modules_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn) + else: + if bn == conv_bn_to_fuse[hash_id].bn_module: + # Do fusion if same bn module + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn = conv_bn_fusion.bn_module + + fused_conv = fuse_conv_bn_eval(conv, bn) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + + gm.graph.lint() + for pattern in module_function_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. + if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + + def _used_by_same_conv_module(users): + conv_module_name = users[0].args[0].target + return all( + conv_module_name == user.args[0].target for user in users + ) + + bn_args_is_constant = all( + n.op == "get_attr" + and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users))) + for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion( + node, + conv, + bn_running_mean=bn_running_mean, + bn_running_var=bn_running_var, + bn_eps=bn_eps, + bn_weight=bn_weight, + bn_bias=bn_bias, + ) + else: + if ( + hash(bn_running_mean) + == hash(conv_bn_to_fuse[hash_id].bn_running_mean) + and hash(bn_running_var) + == hash(conv_bn_to_fuse[hash_id].bn_running_var) + and torch.allclose( + torch.tensor(bn_eps), + torch.tensor(conv_bn_to_fuse[hash_id].bn_eps), + ) + and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight) + and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias) + ): + # Do fusion if same functional bn + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn_running_mean = conv_bn_fusion.bn_running_mean + bn_running_var = conv_bn_fusion.bn_running_var + bn_eps = conv_bn_fusion.bn_eps + bn_weight = conv_bn_fusion.bn_weight + bn_bias = conv_bn_fusion.bn_bias + + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn_running_mean, + bn_running_var, + bn_eps, + bn_weight, + bn_bias, + ) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + gm.graph.lint() + gm.recompile() + + return gm + + +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["weight"] # type: ignore[return-value] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] # type: ignore[return-value] + else: + return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["other"] # type: ignore[return-value] + + +def check_permute(node: torch.fx.Node) -> bool: + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type] + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[union-attr] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def one_user(node): + users = list(node.users) + return users[0] if len(users) == 1 else None + + def is_view(node): + view = {"view"} + return node.op == "call_method" and node.target in view + + def is_pointwise_unary(node): + pointwise = {torch.relu, torch.tanh, "relu", "tanh"} + return node.op in {"call_function", "call_method"} and node.target in pointwise + + g = module.graph + for node in g.nodes: + if node.op != "call_function" or node.target != torch.cat: + continue + + cat_or_view = node + while True: + user = one_user(cat_or_view) + if not user or not is_view(user): + break + cat_or_view = user + + if user and is_pointwise_unary(user): + with g.inserting_before(node): + + def cat_args(tensors, dim=0): + return tensors, dim + + tensors, dim = cat_args(*node.args, **node.kwargs) + new_kwargs = { + name: val for name, val in user.kwargs.items() if name != "input" + } + new_tensors = [ + g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs) + for arg in tensors + ] + new_cat = g.create_node( + "call_function", torch.cat, args=(new_tensors, dim) + ) + user.replace_all_uses_with(cat_or_view) + node.replace_all_uses_with(new_cat) + g.erase_node(user) + g.erase_node(node) + g.lint() + module.recompile() + return module + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes(op="call_method", target="permute"): + if check_permute(node): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(weight, input.transpose(-1, -2)) + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes( + op="call_function", target=torch.nn.functional.linear + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in itertools.chain( + module.graph.find_nodes(op="call_function", target=torch.bmm), + module.graph.find_nodes(op="call_function", target=torch.matmul), + ): + normalized = NormalizedMatmulNode(node) + input_A_node = normalized.get_input() + input_B_node = normalized.get_other() + input_A = input_A_node + input_B = input_B_node + Atrans = Btrans = False + if ( + input_A_node.op == "call_method" + and input_A_node.target == "permute" + and check_permute(input_A_node) + ): + Atrans = True + if len(input_A_node.args) > 0: + input_A = input_A_node.args[0] # type: ignore[assignment] + else: + input_A = input_A_node.kwargs["input"] # type: ignore[assignment] + + if ( + input_B_node.op == "call_method" + and input_B_node.target == "permute" + and check_permute(input_B_node) + ): + Btrans = True + if len(input_B_node.args) > 0: + input_B = input_B_node.args[0] # type: ignore[assignment] + else: + input_B = input_B_node.kwargs["input"] # type: ignore[assignment] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(input_A, input_B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if Atrans and len(input_A_node.users) == 0: + module.graph.erase_node(input_A_node) + if Btrans and len(input_B_node.users) == 0: + module.graph.erase_node(input_B_node) + + module.graph.lint() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(input.transpose(-1, -2), weight.t()) + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul( + A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool +) -> torch.Tensor: + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..6e17e2ff456fc626ea0ff4b1765fe20783242abb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,2589 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import math +import operator +from typing import Any, Tuple + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.node import map_arg + +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +# Only for per tensor quant since permute may changes the channel idx +_PER_TENSOR_QUANTIZE_OPS = [ + quantized_decomposed.quantize_per_tensor.default, + quantized_decomposed.quantize_per_tensor.tensor, +] + +_VIEW_OPS = [ + aten.transpose.int, + aten.permute.default, + aten.view.default, +] + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. +""" + + +def _get_pattern_output_dtype(match: Match): + """ + Get the pattern's output dtype from node's meta + Assume only 1 output node in this matched pattern. + """ + pattern_output_nodes = match.output_nodes() + assert len(pattern_output_nodes) == 1 + output_node = pattern_output_nodes[0] + assert isinstance(output_node, torch.fx.Node) + output_dtype = output_node.meta["val"].dtype + assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16] + return output_dtype + + +def _may_generate_pattern_with_dtype_convert( + pattern, dtype=Arg(), with_dtype_convert=True, users=1 +): + if with_dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + _users=users, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): + # only insert to_dtype if is_bf16 is True + computation_call = _may_generate_pattern_with_dtype_convert( + call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + ) + return unary_fusion(computation_call) + + +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) + return dequantize_per_tensor_activation_pattern + + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_dequantize_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), # x_scale + KeywordArg("x_zp"), # x_zp + KeywordArg("packed_weight"), # packed_weight + KeywordArg("w_scale"), # w_scale + KeywordArg("w_zp"), # w_zp + KeywordArg("b"), # bias + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), # output_scale = 1.0 + KeywordArg("output_zero_point"), # output_zero_point = 0 + KeywordArg("output_dtype"), # output_dtype = None + KeywordArg("attr"), # attr = "none" + Arg(), # scalars + Arg(), # algorithm + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +dequantize_accum_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("accum"), + KeywordArg("accum_scale"), + KeywordArg("accum_zp"), + Arg(), + Arg(), + KeywordArg("accum_dq_dtype"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + dtype_convert=False, + swap_inputs=False, +): + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + dtype_convert, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, + ), + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value == expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv2d_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv2d_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, + unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_conv2d_optimization_pattern(), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 + assert ( + kwargs["attr"] == "none" + ) # Expected no post op fused in weight prepack phase + if unary_attr.op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + unary_attr.scalars_attr = [min_value, max_value] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_unary_matcher_count"] += 1 + counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _register_quantized_linear_lowering( + pattern, + pass_number, + computation_op, + unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_linear_optimization_pattern(), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_unary_matcher_count"] += 1 + counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_binary_optimization_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = ( + kwargs["accum"] + if binary_unary_attr.binary_op_name == "sum" + else kwargs["other"] + ) + x2_scale = 1.0 + x2_zp = 0 + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs["b"] if "b" in kwargs else None + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 + + x2.realize() + from .mkldnn_fusion import _can_be_inplace + + binary_op_name = binary_unary_attr.binary_op_name + + if binary_op_name == "sum" and not _can_be_inplace(x2): + # When we enable the GEMM Template, the output of QLinear + # will be reshaped from 2D back to 3D if the input is 3D. + # This causes _can_be_inplace(x2) to return False if x2 happens + # to be the output of QLinear in this scenario. + # Change the post op from sum to binary add for this case. + # Refer to test case: + # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 + binary_op_name = "add" + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + x2, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_binary_matcher_count"] += 1 + counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv2d_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype in [torch.float32, torch.bfloat16]: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if output_dtype == torch.uint8 or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) + ) + if ( + len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_binary_optimization_pattern(), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = ( + kwargs["accum"] + if output_dtype == torch.uint8 + else kwargs["accum_after_dequant"] + ) + accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0 + accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0 + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace( + accum + ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + + computation_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_unary_attr.binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_binary_matcher_count"] += 1 + counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_fusion(): + from .mkldnn_fusion import ( + _gelu_fusion_1 as _gelu_fusion_erf, + _gelu_fusion_2 as _gelu_fusion_tanh, + _hardswish_fusion, + _hardtanh_fusion, + _silu_fusion, + ) + + class UnaryAttr: + def __init__( + self, op_name: str, scalars_attr=None, algorithm_attr=None + ) -> None: + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # QConv2d + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + conv_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + get_dequantize_qconv_pt2e_pattern(1), + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardtanh_fusion, + get_dequantize_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardswish_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + UnaryAttr("swish", [], ""): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _silu_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardtanh_fusion, + get_dequantize_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + Arg(), + is_bf16, + ), + UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardswish_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _silu_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + unary_attr, # unary_attr + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + qlinear_pattern, + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + ), + UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + unary_attr, # unary_attr + ) + + +def _register_quantization_binary_fusion(): + class BinaryUnaryAttr: + def __init__( + self, + binary_op_name: str, + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ) -> None: + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + binary_replace_patterns = { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + ), + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + ), + } + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + else: + _register_quantized_conv_binary_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # QLinear + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. + """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 0, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs["stride"] if ("stride" in kwargs) else None + padding = kwargs["padding"] if ("padding" in kwargs) else 0 + dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 + ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + counters["inductor"]["qmaxpool2d_matcher_count"] += 1 + counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. + # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_lowmem_maxpool2d_pattern = CallFunction( + prims._low_memory_max_pool2d_with_offsets.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + KeywordArg("offset_dtype"), + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_lowmem_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant( + dequantize_lowmem_maxpool2d_get_item_pattern + ), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + # Get dequant nodes at input + dequant_nodes = filter_nodes( + match.nodes, quantized_decomposed.dequantize_per_tensor.default + ) + zero_points = [node.args[2] for node in dequant_nodes] + # Get quant nodes at output + quant_nodes = filter_nodes( + match.nodes, quantized_decomposed.quantize_per_tensor.default + ) + assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(quant_nodes[0].args[2]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + scales = [node.args[1] for node in dequant_nodes] + scales.append(quant_nodes[0].args[1]) + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] + uint8_inputs = [input[0] for input in inputs] + counters["inductor"]["qcat_matcher_count"] += 1 + counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _is_valid_woq_optimization_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("x", "weight", "scales")) + x = match.kwargs["x"].meta["val"] + weight = match.kwargs["weight"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and weight.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == weight.device + and x.device == scales.device + ) + + return fn + + +def _register_woq_lowering(pattern, computation_woq, computation_reshape): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_woq_optimization_pattern(), + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + weight = kwargs["weight"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = weight.get_size()[0] + origin_x_size = x.get_size() + x_shape = [-1, origin_x_size[-1]] + out_shape = origin_x_size[:-1] + [ + out_features, + ] + func1 = L[computation_reshape](x, x_shape) + func2 = L[computation_woq](func1, weight, scales) + return L[computation_reshape](func2, out_shape) + + return woq + + +def _register_woq_mm_int8_pattern1(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, with x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern2(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, w/o x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern3(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to bmm + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.bmm.default, + CallFunction(aten.expand.default, KeywordArg("x"), Arg()), + CallFunction( + aten.expand.default, + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_quantization_lowerings(): + _register_quantization_unary_fusion() + _register_quantization_binary_fusion() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _register_woq_lowerings(): + _register_woq_mm_int8_pattern1() + _register_woq_mm_int8_pattern2() + _register_woq_mm_int8_pattern3() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + dequant_node = ( + dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- reshape <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- dequant + ) + else: + dequant_node = ( + dequant_pattern_end_node # pattern: linear <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- dequant + ) + + if ( + dequant_node.target + in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. + assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert ( + source_node.op == "call_function" + ), "clone_to_new_node only support node.op call_function" + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dequantize_per_tensor + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * dequantize_per_tensor + def _find_first_node_in_dequant_pattern(_node): + if _node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For a dequant pattern, we expect the start node is a dequantize_per_tensor node + return _node + else: + assert ( + len(_node.args) >= 1 + ), "In in dequant pattern, each node should have more than 1 arg." + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + assert dequant_pattern_start_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv2d_pattern(dtype): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 4 + ): + # Only support conv2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv2d_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv2d_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_node) # type: ignore[arg-type] + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) # type: ignore[arg-type] + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. + # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> dequant + dequant_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> dequant + activation_to_bf16_node = act_reshape_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + dequant_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> dequant + dequant_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> dequant + activation_to_bf16_node = linear_node.args[input_index] + dequant_node = activation_to_bf16_node.args[0] + return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. + ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + assert dequant_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + input_dim_exceeds_two=False, + is_tensor_overload=False, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, + is_tensor_overload=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, + is_tensor_overload=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + is_tensor_overload, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + input_dim_exceeds_two, + is_tensor_overload, + ) + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=is_tensor_overload + ), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, + input_dim_exceeds_two, + is_tensor_overload=is_tensor_overload, + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/ + # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias, is_tensor_overload in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + is_tensor_overload=is_tensor_overload, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +@functools.lru_cache(None) +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() + + +def quant_lift_up(graph_module: torch.fx.GraphModule): + """ + Lift up the quant node before view like nodes. It can benefit performance + of Attention like block. For example, we have the pattern as: + + DQ + DQ LINEAR + LINEAR VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + Q Q + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + We want to lift up the the quant nodes from matmul before view like nodes + as the output of Linear node. + + DQ + DQ LINEAR + LINEAR Q + Q VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + It produces a DQ->LINEAR->Q pattern which can be fused by backend. + """ + + def is_view_op(node): + return node.op == "call_function" and node.target in _VIEW_OPS + + for node in graph_module.graph.nodes: + # Leslie: Here we verify that the quant node has exactly + # one input FX node, with constant scalar value for scale and zero point. + # For the case input of quant node has more than one input FX nodes, + # extend the implementation to lift up all the connected nodes + # before the view nodes to keep the topological order. + if ( + node.op == "call_function" + and node.target in _PER_TENSOR_QUANTIZE_OPS + and len(node.all_input_nodes) == 1 + and is_view_op(node.all_input_nodes[0]) + ): + quant_node = node + input_node_of_quant = quant_node.args[0] + + # Check the nodes along lift up path has only 1 user node + # Propagate view like node to find where to insert the new quant node + could_lift_up = True + current_node = quant_node + input_node = current_node.args[0] + while is_view_op(input_node): + if len(input_node.users) != 1: + could_lift_up = False + break + current_node = input_node + input_node = current_node.args[0] + + # Further check the input node of the first view node has only 1 user node + if could_lift_up and len(input_node.users) == 1: + # Replace dequant's input from quant to quant's input + quant_node.replace_all_uses_with(input_node_of_quant) + # Insert the new quant node + with graph_module.graph.inserting_before(current_node): + new_quant_node = graph_module.graph.node_copy(quant_node) + input_node.replace_all_uses_with(new_quant_node) + + # Update inputs of new_quant_node + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == input_node_of_quant: + return input_node + else: + return n + + new_args = map_arg(new_quant_node.args, maybe_replace_node) + new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) + new_quant_node.args = new_args # type: ignore[assignment] + new_quant_node.kwargs = new_kwargs # type: ignore[assignment] + graph_module.graph.erase_node(quant_node) + + graph_module.graph.lint() + graph_module.recompile() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..59706134f85feaef835f9be5e0ac11ce8e96c69b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py @@ -0,0 +1,688 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, +) +from torch._inductor import config, inductor_prims +from torch._inductor.fx_utils import get_node_storage, is_node_realized +from torch._inductor.lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, +) +from torch._inductor.virtualized import V +from torch.fx.immutable_collections import immutable_dict +from torch.fx.passes.reinplace import _is_view_op +from torch.utils import _pytree as pytree + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass(frozen=True) +class InplaceableOp: + inplace_op: Callable[..., Any] + mutated_arg: int + extra_check: Callable[[torch.fx.Node], bool] = lambda node: True + + +_SCATTER_OP_TO_VIEW = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} +_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()} + + +def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (args, kwargs), + ) + with V.fake_mode: + fake_result = fn(*fake_args, **fake_kwargs) + + node = graph.call_function(fn, args, kwargs) + node.meta["val"] = fake_result + return node + + +@dataclass +class ViewOp: + target: torch._ops.OpOverload + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + + +def _inplace_generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp] +) -> torch.Tensor: + tmp = inp + for view in view_ops: + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (view.args, view.kwargs), + ) + tmp = view.target(tmp, *fake_args, **fake_kwargs) + try: + tmp.copy_(src) + except RuntimeError as e: + raise RuntimeError( + f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}" + ) from e + return inp + + +def _generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp] +) -> torch.Tensor: + out = inp.clone() + return _inplace_generalized_scatter(out, src, view_ops) + + +def _decompose_scatter_functional_helper( + graph: torch.fx.Graph, + inp: torch.Tensor, + src: torch.Tensor, + view_ops: List[ViewOp], +) -> torch.fx.Node: + view_op, view_ops_tail = view_ops[0], view_ops[1:] + + if view_ops_tail: + view = graph_call_function( + graph, view_op.target, inp, *view_op.args, **view_op.kwargs + ) + src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment] + + return graph_call_function( + graph, + _VIEW_OP_TO_SCATTER[view_op.target], + inp, + src, + *view_op.args, + **view_op.kwargs, + ) + + +def _decompose_scatter_functional( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter to a sequence of view_scatter operations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + view = aten.slice(inp, 0, 0, 10) + view_updated = aten.slice_scatter(view, src, 1, 10, -10) + inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10) + """ + assert node.target is _generalized_scatter + inp, src, view_ops = node.args + return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type] + + +def _decompose_scatter_mutating( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter using mutations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + inp_updated = aten.clone(inp) + slice1 = aten.slice(inp_updated, 0, 0, 10) + slice2 = aten.slice(slice1, 1, 10, -10) + slice2.copy_(src) + + """ + assert node.target in (_generalized_scatter, _inplace_generalized_scatter) + inp, src, view_ops = node.args + assert not node.kwargs + + if node.target is _generalized_scatter: + inp = graph_call_function(graph, aten.clone, inp) + + tmp = inp + for view in view_ops: # type: ignore[union-attr] + tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + + graph_call_function(graph, aten.copy_.default, tmp, src) + return inp # type: ignore[return-value] + + +# View ops whose view_scatter op is lowered into mutations anyway, +# so is never a pessimisation to decompose. +_ALWAYS_MUTATING_SCATTER_OPS = { + aten.as_strided.default, + aten.diagonal.default, +} + + +def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: + _, _, view_ops = node.args + return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr] + + +def should_reinplace_scatter(node: torch.fx.Node) -> bool: + """Choose between mutating and functional scatter decompositions + + Reinplacing view scatter ops can be pessimising as it blocks fusion with the + input or output tensor computations. However, it is still profitable if the + input and output would have been realized anyway. + + """ + inp, src, view_ops = node.args + + # Mutating scatter ops unconditionally realize input and output + if scatter_always_uses_mutation(node): + return True + + if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type] + return True + + # If the output is copied back into the input, this forces both to be + # realized as the output is a user of the input + if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr] + user.target is aten.copy_.default and user.args[0] is inp for user in node.users + ): + return True + + # Otherwise, assume fusions will make functional variants profitable + return False + + +def decompose_generalized_scatter(graph: torch.fx.Graph) -> None: + """Replace _generalized_scatter with normal aten ops""" + for node in itertools.chain( + graph.find_nodes(op="call_function", target=_generalized_scatter), + graph.find_nodes(op="call_function", target=_inplace_generalized_scatter), + ): + use_mutation = ( + node.target is _inplace_generalized_scatter + or scatter_always_uses_mutation(node) + ) + + with graph.inserting_before(node): + if use_mutation: + new_node = _decompose_scatter_mutating(graph, node) + else: + new_node = _decompose_scatter_functional(graph, node) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None: + """ + This canonicalizes view scatter ops into a generalized form, defined as: + def scatter(inp, src, views): + tmp = inp.clone() + for view in views: + tmp = view(tmp) + tmp.copy_(src) + + We also fuse consecutive view scatter ops of the form + a = scatter(view2(self), src, [view1]) + b = scatter(self, a, [view2]) + which can be rewritten as + b = scatter(self, src, [view2, view1]) + a = view2(b) + + This is both more efficient as we only do a single scatter, and also + easier to reinplace since there is only one use of `self` + """ + + node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {} + node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list) + + def handle_views(node: torch.fx.Node): + inp = node.args[0] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + node_to_view_op[node] = [ + *node_to_view_op[inp], # type: ignore[index] + ViewOp( + node.target, # type: ignore[arg-type] + args=node.args[1:], + kwargs=node.kwargs, + ), + ] + + def handle_view_scatter(node: torch.fx.Node): + assert len(node.args) >= 2 + inp, src = node.args[:2] + + scatter_view_op = ViewOp( + _SCATTER_OP_TO_VIEW[node.target], + args=node.args[2:], + kwargs=node.kwargs, + ) + + def can_fuse(): + if src.target is not _generalized_scatter: # type: ignore[union-attr] + return False + src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + + inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type] + return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index] + *node_to_view_op[inp], # type: ignore[index] + scatter_view_op, + ] + + if not can_fuse(): + with graph.inserting_before(node): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src, + [scatter_view_op], + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + return + + src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + with graph.inserting_before(src): # type: ignore[arg-type] + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src_src, + [scatter_view_op, *src_scatter_view_op], # type: ignore[misc] + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if src.users: # type: ignore[union-attr] + new_src = graph_call_function( + graph, + _SCATTER_OP_TO_VIEW[node.target], + new_node, + *node.args[2:], + **node.kwargs, + ) + + handle_views(new_src) + src.replace_all_uses_with(new_src) # type: ignore[union-attr] + + graph.erase_node(src) # type: ignore[arg-type] + + for node in graph.nodes: + if _is_view_op(node.target): + handle_views(node) + elif node.target in _SCATTER_OP_TO_VIEW: + handle_view_scatter(node) + + +inplaceable_ops = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), + _generalized_scatter: InplaceableOp( + _inplace_generalized_scatter, + 0, + extra_check=should_reinplace_scatter, + ), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = {triton_kernel_wrapper_functional} + + +# Operators that don't depend on the tensor data +META_ONLY_OPS = { + aten.sym_size.int, + aten.sym_stride.int, + aten.sym_numel.default, + aten.sym_storage_offset.default, +} + + +def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: + """ + Reinplaces in-placeable operations. + If there are no uses of a view of the mutated arg after the current node, + it is possible to inplace the op. + This above algorithm could be justified by observing side effects. While + we traverse the graph in forwards direction, only latter nodes could view + side effects of the current node. If the current node is not used later as + well as no view of this node is used later in the graph, then it is safe to + inplace as there would be no way to observe the side effects. + This condition is slightly different for graph inputs where they can only + be inplaced if the above condition is true and there's a copy_ in the + epilogue that signals that the caller wants to observe the mutation. + + Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from + input args, so instead of checking mutation on placeholder, AOTInductor + checks mutation on get_attr. This is subject to change in future. + """ + + copy_args_to_copy_nodes = {} + # maps argument to the first copy_ node that mutates it. + copy_nodes = {} + mutated_inputs = set() + storage_to_nodes = defaultdict(list) + node_order: Dict[Any, int] = {} + for i, node in enumerate(reversed(graph.nodes)): + node_order[node] = len(graph.nodes) - i - 1 + storage_to_nodes[get_node_storage(node)].append(node) + if node.target == aten.copy_.default and node.args[0].op in ( + "placeholder", + "get_attr", + ): + dst = node.args[0] + src = node.args[1] + # If the target is a getitem and it indexes a possible clone, + # then skip over it + if src.target == operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) + or (src.args[0].target == torch.ops.higher_order.auto_functionalized) + ): + src = src.args[0] + + copy_args_to_copy_nodes[(dst, src)] = node + copy_nodes[dst] = node + + mutated_inputs.add(node.args[0]) + + def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg): + node_loc = node_order[node] + copy_node_loc = node_order[copy_node] if copy_node is not None else None + + def is_meta_only_user(node): + if _is_view_op(node.target): + return all(is_meta_only_user(u) for u in node.users) + return node.target in META_ONLY_OPS + + for view in shared_view_nodes: + for user in view.users: + user_loc = node_order[user] + # Skip all users before node + if user_loc <= node_loc: + continue + # Ignore uses after the copy_ epilogue node, where the input + # has already been mutated anyway + if copy_node_loc is not None and copy_node_loc <= user_loc: + continue + # Reinplacing does not change shape metadata + if is_meta_only_user(user): + continue + # If our graph looks like: + # foo(mutated_arg) + # mutated_arg.copy_(other) + # then it's safe for us to reinplace foo because mutated_arg + # will get overwritten anyways. + if ( + user.target is torch.ops.aten.copy_.default + and mutated_arg is user.args[0] + ): + continue + return True + return False + + def can_inplace(node, mutated_arg): + if isinstance(mutated_arg, (list, tuple)): + unique_storages = {get_node_storage(arg) for arg in mutated_arg} + if len(unique_storages) != len(mutated_arg): + # at least two Tensors in mutated_arg alias each other, so we can't reinplace it. + # We can probably do better (that is, reinplace one of them and clone the other) + # but that requires more work and mutable List[Tensor] are not that common. + return False + return all(can_inplace(node, arg) for arg in mutated_arg) + + if get_node_storage(mutated_arg) is None: + return False + shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + + if mutated_arg.op in ("placeholder", "get_attr"): + # Get the first copy_ node that mutates the mutated_arg. + copy_node = copy_nodes.get(mutated_arg, None) + if copy_node is None: + # There is no copy_ back to the candidate mutated_arg (which is a graph input). + # Therefore the semantics of the program are that it does not mutate + # mutated_arg, so we cannot re-inplace it. + return False + if any_use_of_views_after_node( + node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg + ): + return False + + return True + elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + + # If mutated arg is view of any of the inputs of the graph, + # do not allow for inplacing. + # This would require more sophisticated algorithm to handle + return False + else: + return not any_use_of_views_after_node( + node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg + ) + + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ): + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ) + torch._dynamo.utils.counters["inductor"][ + "possibly_missed_reinplacing_opportunities" + ] += len(possibly_missed_reinplacing_opportunities) + + replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {} + + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False + ): + tensors_to_clone: List[str] = [] + storage_of_reinplaced_args = set() + possibly_missed_reinplacing_opportunities = [] + + def tensor_with_same_storage_already_reinplaced(arg): + if isinstance(arg, (list, tuple)): + return any( + get_node_storage(a) in storage_of_reinplaced_args for a in arg + ) + return get_node_storage(mutated_arg) in storage_of_reinplaced_args + + for arg in old_tensors_to_clone: + assert arg in kwargs + + mutated_arg = kwargs[arg] + + # Let's say we have: + # - op(x, y) that mutates both x and y + # - new_x, new_y = functional_op(x, y) is the functional variant + # If we are presented with functional_op(x, x), we must not reinplace + # this into op(x, x), because then it would be writing to the same Tensor. + # Instead, it's OK to reinplace one of them and to clone the other: + # >>> y = x.clone() + # >>> op(x, y) + # This also applies if we have views: functional_op(x, x[0]) + # should not reinplace into op(x, x[0]). + should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced( + mutated_arg + ) + if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + if not auto_functionalize_v2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output atindex size(out)+i. + # This used to compare string with integers before for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg + + if isinstance(mutated_arg, (list, tuple)): + for a in mutated_arg: + storage_of_reinplaced_args.add(get_node_storage(a)) + else: + storage_of_reinplaced_args.add(get_node_storage(mutated_arg)) + else: + if should_attempt_reinplace: + possibly_missed_reinplacing_opportunities.append(arg) + tensors_to_clone.append(arg) + + log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ) + return tensors_to_clone + + for node in graph.nodes: + if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: + mutated_arg = node.args[inplaceable_op.mutated_arg] + if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node): + # TODO(yifu): this doesn't properly remove copy epilogues for + # ops that mutate multiple inputs. Need to revise the copy + # node tracking logic to support the case. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + auto_functionalize_v2=True, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone + elif node.target == torch.ops.higher_order.auto_functionalized: + _mutable_op = node.args[0] + from torch._higher_order_ops.auto_functionalize import get_mutable_args + + tensors_to_clone, _ = get_mutable_args(_mutable_op) + # Don't try to reinplace Optional[Tensor] args that are None. + tensors_to_clone = [ + t for t in tensors_to_clone if node.kwargs[t] is not None + ] + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + tensors_to_clone, + node.kwargs, + _mutable_op._name, + auto_functionalize_v2=False, + ) + + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = tensors_to_clone + elif node.target in inplaceable_triton_ops: + kernel_idx = node.kwargs["kernel_idx"] + kernel = kernel_side_table.get_kernel(kernel_idx) + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + if isinstance(kernel, JITFunction): + kernel_name = kernel.fn.__name__ + elif isinstance(kernel, Autotuner): + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ + else: + raise AssertionError("Unknown triton kernel type") + + # inplaceable_triton_ops take an additional argument called + # tensors_to_clone which contain a list of tensors to clone + # This pass iterates over them and sees which ones are safe + # to eliminate (i.e. no longer need the clones) + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name + ) + + kwargs = dict(node.kwargs) + kwargs["tensors_to_clone"] = tensors_to_clone + node.kwargs = immutable_dict(kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + replace_dict[copy_node] = copy_node.args[0] + + node.target = inplaceable_op.inplace_op + for node, replacement in replace_dict.items(): + while replacement in replace_dict: + replacement = replace_dict[replacement] + replace_dict[node] = replacement + + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + +def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: + canonicalize_view_scatter_ops(graph) + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..f56a90b86dd5c1c35ebba7c97fbff1f4497313c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,145 @@ +# mypy: allow-untyped-defs +import collections +import logging + +import torch +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + with GraphTransformObserver( + gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform + ): + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + match.replace_by_example(replacement, [size]) + + +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(low, high, size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + match.replace_by_example(replacement, [low, high, size]) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..025b66ce5366375b60d6d22e0ae9e305114b0e1c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a7c63050a2cbadabbad72b24d965c29befa8780 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f6fbee3aae3784806fe1d4474e726e8d8c4bb64 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc755f1c930e897af49889fa46c8166028a6c3e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd1f89507a1db19b59e21b493c7fde074c7d273 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49b3a23620a889ec386fdd5fa55b53d92997d03c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b8f7f54a2f2f14599c020ea28c6cb476704b541 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a87eaefe79e8c7bf1c83a81f07d1d76b3bef408 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c7f1850a7578dc081b52dab8cab931040439c27 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..291e69726eaffc7c509f6e6595ea362c0bc9beab Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e7778a4d3fe2c53672fe59f2f8bfab50230a76c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78573f629bb711b881db7367b0f4f559a8d8b405 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c37f16465841a4f40f37deadf1df6f774dfeb9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c0a2cb127de80163b9bb7fedf3b1d19279a21b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63c9afcc5d4f002e307dbf89046f598fbc7e4f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c8bffdb3d3a3b1f41dfc3e735a95a84817a3c0b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e2abd79669a6774df43e75673eef2b49006a94f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..868107f5e3cb2f7a69feebc68d22a4f3b541eff3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2751caf03fff1d4cce115253ef5e1c8c2c7fd3a9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b0b81f166da74e7b4e3249c8e56a29109afd1e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..013b6b4f29934e7a27fd20a27edff86b4cea9528 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea07086d4f7f1d79dd234818a59dfe9264d7987 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..409ca3249f1976e23e1836ed1e31f96f974d2283 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..55d2216b4e1fb29e8c5c8f25fd181bb9f8b41fc3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,173 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..860ef1c8551fb520f615a0da612ee8fc4088b88c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..d8119c33ed93b8534270e8ea85410e94a549d027 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,203 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..40834960904aa7d8aa7becc4fe99cf25a37d97ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,219 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..bef5eab2bee969ecb5541a32f515f3ed10187e01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,129 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e87c009fcca38bcbeee044d19308958aeb9585 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..289585111a54bf37d1f017adbcf73afd7ae03f94 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,227 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c1b5c6023539c1caa0474561bff2e18cd63803 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,598 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..f741b23c0dd3bd166ec9687cf91a78cf41bf839e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new file mode 100644 index 0000000000000000000000000000000000000000..25c482876a998cd70d7d08302ee191d3fc2666b4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,452 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..3cba2215bc7619e367660e47f0f3c27889063fa8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,208 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f573cb373491300e370e2dfaa541a4f4756d0fc3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,173 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..d7eb251ba52d497f698d4a087ebd059ba16dbf4c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,189 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..773b2be31bde5a9afe6df90dbffc5e264e386037 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,189 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..fe481c8293be744e46c1287758043d4be82637ab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,177 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..7de8b8229ea8b66455ffb0a2b86c397217048767 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,193 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..ff198232b5e6d84511087a483cd51ff2ec8f43d8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4b27c8a6fb6d3f6da8c178d46308bf61584333 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..78380c1bb341e1a8de1fe269dd05e34e8604d7ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..58864035cb936de9418d0b87bda304c0a84d2b25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,52 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..b1363546bd11198e3a3488e8b71899b87c431cee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,44 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..5380feb34d56f3edb3c68ced3f1ff70e452eb0c5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,44 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..d0098a535ee295fe5c5e4cd97600b08a638a6f02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,2514 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + PatternMatcherPass, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS + + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = Tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = Tuple[int, int] + + +PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} +POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} + +pre_grad_pass_names = [ + "normalization_pass", + "remove_split_with_size_one_pass", + "merge_getitem_cat_pass", + "merge_stack_tahn_unbind_pass", + "merge_splits_pass", + "mutate_cat_pass", + "split_cat_pass", + "unbind_stack_pass", + "split_cat_to_slices_pass", + "unbind_cat_to_view_pass", + "split_stack_to_cats_pass", + "unbind_stack_to_slices_pass", + "move_reshape_out_of_split_stack_pass", +] + +post_grad_pass_names = [ + "normalization_aten_pass", + "decompose_mm_pass", + "unbind_stack_aten_pass", + "shape_padding_multiplier", +] + +for pass_name in pre_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in PRE_GRAD_FUSIONS: + continue + PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + +for pass_name in post_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in POST_GRAD_FUSIONS: + continue + POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + + +def construct_pattern_matcher_pass(pass_name: str): + """ + Return the specific pattern_matcher_pass given the pass name. + """ + if pass_name in PRE_GRAD_PATTERNS: + return PRE_GRAD_PATTERNS[pass_name] + else: + return POST_GRAD_PATTERNS[pass_name] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +def _get_dim(node: Any): + assert isinstance(node, torch.fx.Node) + if "dim" in node.kwargs: + assert isinstance(node.kwargs["dim"], int) + return node.kwargs["dim"] + if node.target == torch.unbind: + if len(node.args) == 2: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + if node.target == torch.split: + if len(node.args) == 3: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + raise AssertionError( + f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +def remove_split_with_size_one(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + # remove the dummy split whose split sections size is one + if len(split_sections) == 1: + # find the grand children of the split_node + next_users = find_next_users(split_node) + user = next(iter(split_node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, split_input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(split_node) + counters["inductor"]["remove_split_with_size_one_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if not is_node_meta_valid(input): + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + new_args = (tensors,) + new_kwargs = {"dim": cat_dim} + if ( + cat_node.args == new_args + and cat_node.kwargs == new_kwargs + and cat_node.op == "call_function" + ): + return + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + kwargs=new_kwargs, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, # type: ignore[arg-type] + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + new_squeeze_node.meta.update(squeeze_node.meta) + match.graph.erase_node(squeeze_node) + + +@register_graph_pattern( + CallMethodVarArgs("reshape", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_reshape_default(match: Match, *args, **kwargs): + reshape_node = match.nodes[0] + if not is_node_meta_valid(reshape_node): + log.debug("example value absent for node: %s", reshape_node) + return + reshape_input = get_arg_value(reshape_node, 0) + + with match.graph.inserting_after(reshape_node): + new_reshape_node = match.graph.call_function( + torch.reshape, + args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), + ) + reshape_node.replace_all_uses_with(new_reshape_node) + new_reshape_node.meta.update(reshape_node.meta) + match.graph.erase_node(reshape_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. + """ + + def __init__(self, arg, sizes, func=torch.split) -> None: + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = set() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: List[int], + next_split_sections: List[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = _get_dim(first_split) + + to_remove = [] + + with graph.inserting_before(first_split): # type: ignore[arg-type] + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + if is_node_meta_valid(first_split_input): + new_split.meta["example_value"] = torch.split( + first_split_input.meta["example_value"], + new_split_sections, + dim=first_split_dim, + ) + first_split_num_to_user = { + user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + # We use the num of users to check the getitems to be merged. + for next_split_num in range(len(node.users.keys())): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["merge_splits_pass"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters["inductor"]["unbind_stack_pass"] += 1 + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] + ) -> List[List[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in {torch.cat, torch.stack}: + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> List[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = set(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> List[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = set(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: List[Union[torch.fx.Node, int]] + ) -> List[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + ranges = set() + for user_node, user_inputs in zip(next_users, user_inputs_list): + ranges |= { + user_input + for user_input in user_inputs + if isinstance(user_input, tuple) + } + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. + return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = _get_dim(split_node) + split_sections = split_node.args[1] + transform_params_list: List[List[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in {torch.cat, torch.stack}: + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + 1 + ] + # All sections should be equal + if len(set(subset_split_sections)) != 1: + return None + + num_splits = len(subset_split_sections) + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + split_ranges: List[_Range], + ) -> List[List[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. + """ + split_input = split_node.args[0] + split_dim = _get_dim(split_node) + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] + new_split.meta["example_value"] = torch.split( + split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr] + ) + counters["inductor"]["scmerge_split_added"] += 1 + split_items = [] + with graph.inserting_after(new_split): + for i in range(len(split_ranges)): + getitem = graph.call_function(operator.getitem, args=(new_split, i)) + if is_node_meta_valid(new_split): + getitem.meta["example_value"] = new_split.meta["example_value"][ + i + ] + split_items.append(getitem) + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list_new, + transform_params_list: List[List[_TransformParam]], + ): + split_dim = _get_dim(split_node) + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in {torch.cat, torch.stack}: + # Change the args and kwargs of non-cat/stack nodes. Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack, to_stack_meta = [], [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + if not is_node_meta_valid(user_input_new): + log.debug("example value absent for node: %s", user_input_new) + return + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + to_stack_meta.append(user_input_new.meta["example_value"]) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + to_stack, to_stack_meta = [], [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + to_stack_meta.append(user_input_new.meta["example_value"]) + continue + + if unflatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + if movedim_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr] + if flatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + user_inputs_new_transformed.append(user_input_new) + user_inputs_new_transformed_meta.append( + user_input_new.meta["example_value"] + ) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + user_inputs_new_transformed_meta, dim=cat_dim + ) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + new_cat_node.meta[ + "example_value" + ] = user_inputs_new_transformed_meta[-1] + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node_meta = new_cat_node.meta["example_value"] + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr] + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in {torch.cat, torch.stack}: + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + if len(node.users.keys()) == 0: + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. + """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + if not is_node_meta_valid(unbind_node): + return + # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node + # before we do the unbind remove, otherwise it will hit the error when we unbind part of them + getitem_indices = [] + for getitem_node in unbind_node.users.keys(): + getitem_indices.append(getitem_node.args[1]) + if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] + getitem_indices + ) != len( + unbind_node.meta["example_value"] + ): + return + num_unbind = len(getitem_indices) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: List[int], + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = _get_dim(split_node) + transform_params_list: List[List[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1) -> None: + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + if is_node_meta_valid(split_input): + unbind.meta["example_value"] = torch.unbind( + split_input.meta["example_value"], dim=dim + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_cat_pass"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +reshape_getitem_split = ListOf( + CallFunction( + torch.reshape, + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def simplify_split_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... \ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: + return False + return True + + +def remove_zeros(split_sections: List[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: List[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), +) +def merge_getitem_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + indices = [] + for arg in cat_user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["merge_getitem_cat_pass"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), +) +def mutate_cat_node(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, list(range(indices[0])) + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, indices + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.debug("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 + + +def divide_into_consecutive_sublists(indices: List[int]) -> List[List[int]]: + n = len(indices) + if n <= 1: + return [indices] + + # Initialize the list of sublists + sublists = [] + + # Iterate over the indices + i = 0 + while i < n: + # Initialize the current sublist + sublist = [indices[i]] + + # Iterate over the remaining indices + j = i + 1 + while j < n and indices[j] == indices[j - 1] + 1: + # Add the next index to the current sublist + sublist.append(indices[j]) + j += 1 + + # Add the current sublist to the list of sublists + sublists.append(sublist) + # Move to the next index + i = j + + return sublists + + +def update_args_from_split_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, + getitem_indices: List[int], + parents_seen: List[torch.fx.Node], + new_cat_args: List[torch.fx.Node], + new_cat_args_meta: List[torch.fx.Node], + idx_to_getitems: Dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1]) + # case 1: the number of getitems is the same as the split size, elimiate the split + if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( + getitem_indices + ): + # we can merge the getitems from the previous parent + new_cat_args.append(split_input) + new_cat_args_meta.append(split_input.meta["example_value"]) + else: + if len(getitem_indices) > 0: + # case 2: the number of getitems is smaller than the split size but larger than the threshold, and + # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups + geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) + for sublist in geitem_indices_sublist: + if len(sublist) >= threshold_to_cat: + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + start_fused_size = sum(split_size[: sublist[0]]) + end_fused_size = sum(split_size[: sublist[-1] + 1]) + slice_list = [] + for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append( + slice(start_fused_size, end_fused_size, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(split_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = split_input.meta[ + "example_value" + ][tuple(slice_list)] + new_cat_args.append(slice_node) + new_cat_args_meta.append(slice_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in sublist: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append( + idx_to_getitems[i].meta["example_value"] + ) + + +def reshape_cat_node( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + unbind_input: torch.fx.Node, + cat_dim: int, + unbind_dim: int, + cat_shape: torch.Size, +) -> torch.fx.Node: + if cat_dim != unbind_dim: + # construct the permute node args, which has the same shape as the slice node + # then it has the same dim as the unbind_input, i.e., shape of cat + 1 + with graph.inserting_after(cat_node): + permute_list = list(range(len(cat_shape) + 1)) + permute_list[unbind_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[unbind_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(unbind_input, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + unbind_input.meta["example_value"], permute_list + ) # type: ignore[arg-type] + else: + permute_node = unbind_input + with graph.inserting_after(permute_node): + reshape_node = graph.call_function( + torch.reshape, args=(permute_node, tuple(cat_shape)) + ) + reshape_node.meta["example_value"] = torch.reshape( + permute_node.meta["example_value"], tuple(cat_shape) + ) # type: ignore[arg-type] + return reshape_node + + +def update_args_from_unbind_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, # cat or stack node + getitem_indices: List[int], + parents_seen: List[torch.fx.Node], + new_cat_args: List[torch.fx.Node], + new_cat_args_meta: List[torch.fx.Node], + idx_to_getitems: Dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input + unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim + cat_dim = get_arg_value(node, 1, "dim") # cat or stack dim + # case 1: the number of getitems is the same as the split size, elimiate the split + size = list(unbind_input.meta["example_value"].shape)[unbind_dim] + if size == len(getitem_indices): + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + # we can merge the getitems from the previous parent + reshape_node = reshape_cat_node( + graph, node, unbind_input, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( + getitem_indices + ): + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + slice_list = [] + for i in range(len(cat_shape) + 1): + if i != unbind_dim: + slice_list.append(slice(None, None, None)) # start, end, step + else: + slice_list.append( + slice(getitem_indices[0], getitem_indices[-1] + 1, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(unbind_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = torch.narrow( + unbind_input.meta["example_value"], + unbind_dim, + getitem_indices[0], + getitem_indices[-1] - getitem_indices[0] + 1, + ) + reshape_node = reshape_cat_node( + graph, node, slice_node, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in getitem_indices: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) + + +def construct_cat_args( + graph: torch.fx.Graph, + cat_or_stack_node: torch.fx.Node, + inputs: List[torch.fx.Node], + split_or_unbind_node: torch.fx.Node, + threshold_to_cat: int = 2, + run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] +) -> Tuple[List[torch.fx.Node], List[torch.Tensor]]: + new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] + new_cat_args_meta = [] # type: ignore[var-annotated] + for input in inputs: + if input.target != operator.getitem: + # update the last arg based on getitem_indices and parents_seens + if len(parents_seen) > 0: + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + # reset the indices array + getitem_indices, idx_to_getitems = [], {} + else: + # get the parent node of the getitem input + parent, idx = input.args[0], input.args[1] # type: ignore[union-attr] + if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # cannot use parents_seen to check since the first item could be non getitem node + if len(parents_seen) == 0: + parents_seen.append(parent) + idx_to_getitems[idx] = input + getitem_indices.append(idx) + # case: we only have one getitem input, and it is in the last position + if input == inputs[-1]: + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # if it is the last input in the tensors, we also check if it can be optimized + if parent != parents_seen[-1] or input == inputs[-1]: + if input == inputs[-1]: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + # reset the indices array for the next parent + # remember to add the last element since it is the first + # item in this round of parent + # add the parent to the list of seen parents + parents_seen.append(parent) + getitem_indices, idx_to_getitems = [idx], {idx: input} + else: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + return new_cat_args, new_cat_args_meta + + +def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.Node]): + nodes = set() + for input in inputs: + if input.target == operator.getitem: + nodes.add(input.args[0]) # type: ignore[union-attr] + if len(input.users.keys()) == 0: + graph.erase_node(input) + # check the split node to remove if it has no users + for node in nodes: + if len(node.users.keys()) == 0: # type: ignore[union-attr] + graph.erase_node(node) # type: ignore[arg-type] + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# other inputs getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) other inputs -> -> user=multiple +# / \ +# cat (user=mul, dim=1, split_node) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +def split_cat_to_slices(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_nodes = [node for node in match.nodes if node.target == torch.split] + if split_nodes: + split_node = next(node for node in split_nodes) + else: + # Handle the case where there are no nodes with a target of torch.split + return + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_cat_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + next_users = find_next_users(split_node) + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, _ = construct_cat_args( + graph, + cat_node, + cat_inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: if new cat args has length 1, we can remove the cat node + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["split_cat_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): + new_args = (new_cat_args,) + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + # split and cat have the same dim + kwargs={"dim": split_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) + counters["inductor"]["split_cat_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=0) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), +) +def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_cat_to_view_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + cat_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + # get the view shape + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(cat_node, 1, "dim") + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type] + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + + +def reshape_cat_node_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + split_or_unbind_dim: int, +) -> None: + # reshape the cat node to the stack node shape + stack_shape = stack_node.meta["example_value"].shape + stack_dim = _get_dim(stack_node) + if stack_dim != split_or_unbind_dim: + # case 1: the stack dim is not the same as the split dim + # we need to reshape the split input before we do the reshape + reshape_list = list(stack_shape) + reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( + reshape_list[split_or_unbind_dim], + reshape_list[stack_dim], + ) + reshape_node = graph.call_function( + torch.reshape, + args=(cat_node, tuple(reshape_list)), + ) + reshape_node.meta["example_value"] = torch.reshape( + cat_node.meta["example_value"], tuple(reshape_list) + ) + permute_list = list(range(len(stack_shape))) + permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( + permute_list[split_or_unbind_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(reshape_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + reshape_node.meta["example_value"], permute_list + ) + else: + # case 2: the stack dim is the same as the split dim + # we can directly reshape the split input + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, *stack_shape), # type: ignore[arg-type] + ) + stack_node.replace_all_uses_with(reshape_node) + reshape_node.meta.update(stack_node.meta) + stack_inputs = stack_node.args[0] # type: ignore[union-attr] + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + + +def convert_reshape_cat_arg_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + stack_node_shape: torch.Size, + stack_dim: int, + split_dim: int, +) -> torch.fx.Node: + # reshape the cat node to the stack node shape + cat_shape = cat_node.meta["example_value"].shape + if stack_dim != split_dim: + permute_list = list(range(len(cat_shape))) + permute_list[stack_dim], permute_list[split_dim] = ( + permute_list[split_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(cat_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + cat_node.meta["example_value"], permute_list + ) + else: + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] + ) + reshape_node.meta["example_value"] = torch.Tensor.view( + permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type] + ) + return reshape_node + + +# ############pattern to be optimized is######### +# | | +# split split (dim=1) +# / \ / \ +# getitem ... getitem other ops +# \ | / / +# stack(user=mul, dim=1 or 2) -> can be different dim +# | + +# ################after transformation############# + +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / +# cat(user=mul, dim=1) cat_other_opts +# \ / +# cat +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), +) +def split_stack_to_cats(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_stack_to_cats_pass" + ].get("threshold_to_cat", 10) + # get the stack_node and check its inputs and meta data + next_users = find_next_users(split_node) + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": split_dim}, + ) + cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] + new_cat_args_meta, dim=split_dim + ) + reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=1) -> user=multiple +# \ ... / \ +# others getitem getitem getitem -> user=multiple +# \ \ / \ +# stack(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view others +# | / +# stack +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), +) +def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_stack_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(stack_node, 1, "dim") + with graph.inserting_after(stack_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) + reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### +# input +# | +# split(dim=1) -> user=multiple +# \ \ +# others getitem getitem +# \ \ / +# reshape reshape reshape other_op +# \ \ / / +# stack(user=mul, dim=0) +# | + +# ################after transformation############# +# input +# | +# permute +# | +# reshape others +# | / +# cat (dim=0) +# | + + +def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> List[int]: + # cat_arg must be the split input + view_shape_list = [] + for user in cat_arg.users.keys(): + if user.target == torch.split: + for getitem in user.users.keys(): + if getitem.target == operator.getitem: + reshape_user = [ + user + for user in getitem.users.keys() + if user.target == torch.reshape + ] + if len(reshape_user) > 0: + view_shape_list = list( + reshape_user[0] + .meta["example_value"] + .unsqueeze(stack_dim) + .shape + ) + view_shape_list[stack_dim] = -1 + return view_shape_list + return view_shape_list + + +@register_graph_pattern( + CallFunction( + torch.stack, + reshape_getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), +) +def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = _get_dim(split_node) + split_users = list(split_node.users.keys()) + stack_nodes = [node for node in match.nodes if node.target == torch.stack] + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "move_reshape_out_of_split_stack_pass" + ].get("threshold_to_cat", 10) + for stack_node in stack_nodes: + if not is_node_meta_valid(stack_node): + log.debug("example value absent for node: %s", stack_node) + continue + stack_dim = _get_dim(stack_node) + stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + inputs = [] + for stack_input in stack_inputs: + if stack_input.target != torch.reshape: + inputs.append(stack_input) + else: + inputs.append(stack_input.args[0]) # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_node = convert_reshape_cat_arg_to_stack( + graph, + new_cat_args[0], + stack_node, + stack_node.meta["example_value"].shape, + stack_dim, + split_dim, + ) + stack_node.replace_all_uses_with(reshape_node) + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # decompose the cat args into multiple stack nodes, i.e., we stack + # all the nodes exist in the stack inputs and reshape the rest followed by a cat + stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] + for cat_arg in new_cat_args: + if cat_arg not in stack_inputs: + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + # cat_arg must be the split input + view_shape_list = get_view_shape_list(cat_arg, stack_dim) + stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr] + cat_inputs.append( + convert_reshape_cat_arg_to_stack( + graph, + cat_arg, + stack_node, + stack_node_shape, + stack_dim, + split_dim, + ) + ) + stack_node_input, stack_node_input_meta = [], [] + else: + stack_node_input.append(cat_arg) + stack_node_input_meta.append(cat_arg.meta["example_value"]) + + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(cat_inputs,), + kwargs={"dim": stack_dim}, + ) + stack_node.replace_all_uses_with(cat_node) + cat_node.meta.update(stack_node.meta) + graph.erase_node(stack_node) + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1