Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +231 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +94 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +71 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml +389 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py +201 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py +401 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py +1025 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py +761 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py +1311 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py +64 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py +200 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py +243 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py +600 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py +6 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py +458 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py +16 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc
ADDED
|
Binary file (23.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc
ADDED
|
Binary file (7.31 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc
ADDED
|
Binary file (9.12 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (333 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc
ADDED
|
Binary file (1.65 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc
ADDED
|
Binary file (7.57 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import operator
|
| 3 |
+
import traceback
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Callable, Dict, List, NamedTuple, Set
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.fx
|
| 11 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, ProxyValue, PassResult
|
| 12 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 13 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__all__ = ["InputDim"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class InputDim(NamedTuple):
|
| 20 |
+
input_name: str
|
| 21 |
+
dim: int
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _convert_to_int(val):
|
| 25 |
+
# Convert simple sympy Integers into concrete int
|
| 26 |
+
if val == sympy.oo:
|
| 27 |
+
return math.inf
|
| 28 |
+
if val == -sympy.oo:
|
| 29 |
+
return -math.inf
|
| 30 |
+
if isinstance(val, sympy.Integer):
|
| 31 |
+
return int(val)
|
| 32 |
+
raise RuntimeError(
|
| 33 |
+
"Export constraints cannot be non-integer expressions"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _convert_range_to_int(range: ValueRanges):
|
| 38 |
+
assert isinstance(range, ValueRanges)
|
| 39 |
+
min_val = _convert_to_int(range.lower)
|
| 40 |
+
max_val = _convert_to_int(range.upper)
|
| 41 |
+
return min_val, max_val
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class _AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
|
| 51 |
+
self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
|
| 52 |
+
self.counter = 0
|
| 53 |
+
|
| 54 |
+
def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
|
| 55 |
+
if lower > -math.inf:
|
| 56 |
+
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
|
| 57 |
+
|
| 58 |
+
if upper < math.inf:
|
| 59 |
+
self._insert_assert_async(operator.le, proxy, upper, assert_msg)
|
| 60 |
+
|
| 61 |
+
def _insert_assert_async(self, operator, lower, upper, assert_msg):
|
| 62 |
+
"""
|
| 63 |
+
Inserts assert_async call_function nodes in the graph. This function is
|
| 64 |
+
called **during** the interpreter-based pass.
|
| 65 |
+
"""
|
| 66 |
+
self.counter += 1
|
| 67 |
+
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
|
| 68 |
+
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
|
| 69 |
+
super().call_operator(
|
| 70 |
+
torch.ops.aten._assert_async.msg,
|
| 71 |
+
(cmp_tensor, assert_msg),
|
| 72 |
+
{},
|
| 73 |
+
self._create_dummy_node_metadata(),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
|
| 77 |
+
ret = super().call_operator(op, args, kwargs, meta)
|
| 78 |
+
if "val" not in meta:
|
| 79 |
+
return ret
|
| 80 |
+
|
| 81 |
+
val = meta["val"]
|
| 82 |
+
|
| 83 |
+
# In general, we may have to deal the case such as: ret[1].shape[0].
|
| 84 |
+
# We need first find out what symbols require assertion, then we need to follow the path
|
| 85 |
+
# from ret to the symbol, construct the proxies along the way and construct the messages
|
| 86 |
+
# piece-wise at the same time.
|
| 87 |
+
#
|
| 88 |
+
# We use post-order traversal to collect all the proxies callbacks needed, construct
|
| 89 |
+
# the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
|
| 90 |
+
# We need the callbacks because, in order to call the function to create a proxy for shape[0], we
|
| 91 |
+
# need the proxy for shape, which further requires the proxy for ret[1], etc.
|
| 92 |
+
def add_assertions(val):
|
| 93 |
+
call_backs: List[Callable] = []
|
| 94 |
+
messages: List[str] = []
|
| 95 |
+
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
| 96 |
+
symbol = val.node.expr
|
| 97 |
+
if symbol in self.existing_inline_assertions:
|
| 98 |
+
return call_backs, messages
|
| 99 |
+
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
|
| 100 |
+
if symbol in self._asserts_generated_unbacked_symbols:
|
| 101 |
+
return call_backs, messages
|
| 102 |
+
# We only care about unbacked symints for these inline
|
| 103 |
+
# constraints, which are prefixed with 'u'
|
| 104 |
+
constraint = self.range_constraints[symbol]
|
| 105 |
+
min_val, max_val = _convert_range_to_int(constraint)
|
| 106 |
+
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
|
| 107 |
+
call_backs.append(
|
| 108 |
+
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
|
| 109 |
+
)
|
| 110 |
+
messages.append(assert_msg)
|
| 111 |
+
self._asserts_generated_unbacked_symbols.add(symbol)
|
| 112 |
+
|
| 113 |
+
elif isinstance(val, torch.Tensor):
|
| 114 |
+
for i, sym in enumerate(val.shape):
|
| 115 |
+
cbs, msgs = add_assertions(sym)
|
| 116 |
+
for cb, msg in zip(cbs, msgs):
|
| 117 |
+
def sym_size_cb(proxy, assert_msg, dim):
|
| 118 |
+
dim_proxy = super(
|
| 119 |
+
_AddRuntimeAssertionsForInlineConstraintsPass,
|
| 120 |
+
self
|
| 121 |
+
).call_operator(
|
| 122 |
+
torch.ops.aten.sym_size.int,
|
| 123 |
+
(proxy, dim),
|
| 124 |
+
{},
|
| 125 |
+
self._create_dummy_node_metadata(),
|
| 126 |
+
)
|
| 127 |
+
cb(proxy=dim_proxy, assert_msg=assert_msg)
|
| 128 |
+
call_backs.append(partial(sym_size_cb, dim=i))
|
| 129 |
+
messages.append(f".shape[{i}]" + msg)
|
| 130 |
+
return call_backs, messages
|
| 131 |
+
|
| 132 |
+
callbacks, messages = add_assertions(val)
|
| 133 |
+
for cb, msg in zip(callbacks, messages):
|
| 134 |
+
cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
|
| 135 |
+
return ret
|
| 136 |
+
|
| 137 |
+
def call(self, graph_module):
|
| 138 |
+
self.existing_inline_assertions = _get_existing_inline_assertions(
|
| 139 |
+
graph_module, self.range_constraints
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Add runtime asserts for inline constraints
|
| 143 |
+
val = super().call(graph_module)
|
| 144 |
+
|
| 145 |
+
# Sometimes this pass would return a wrong graph where we have mismatched
|
| 146 |
+
# node names in signature. Before we fix it, let's just skip it.
|
| 147 |
+
if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
|
| 148 |
+
return PassResult(graph_module, False)
|
| 149 |
+
|
| 150 |
+
# Populate the stack trace with dummy vals to respect IR
|
| 151 |
+
for node in val.graph_module.graph.nodes:
|
| 152 |
+
if not node.meta.get("stack_trace", None):
|
| 153 |
+
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
|
| 154 |
+
|
| 155 |
+
return PassResult(val.graph_module, val.modified)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _get_existing_inline_assertions(
|
| 159 |
+
graph_module: torch.fx.GraphModule,
|
| 160 |
+
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
| 161 |
+
) -> Dict[sympy.Symbol, ValueRanges]:
|
| 162 |
+
existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
|
| 163 |
+
|
| 164 |
+
for module in graph_module.modules():
|
| 165 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 166 |
+
continue
|
| 167 |
+
|
| 168 |
+
# Find all the existing inline assertions. They will look something like:
|
| 169 |
+
# %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
|
| 170 |
+
# %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
|
| 171 |
+
# %scalar_tensor = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,), kwargs = {})
|
| 172 |
+
# %_assert_async = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor, "..."), kwargs = {})
|
| 173 |
+
for node in module.graph.nodes:
|
| 174 |
+
if node.target != torch.ops.aten._assert_async.msg:
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
scalar_tensor_arg = node.args[0]
|
| 178 |
+
if not (
|
| 179 |
+
scalar_tensor_arg.op == "call_function" and
|
| 180 |
+
scalar_tensor_arg.target == torch.ops.aten.scalar_tensor.default
|
| 181 |
+
):
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
compare_arg = scalar_tensor_arg.args[0]
|
| 185 |
+
if not (
|
| 186 |
+
compare_arg.op == "call_function" and
|
| 187 |
+
compare_arg.target in (operator.le, operator.ge) and
|
| 188 |
+
len(compare_arg.args) == 2
|
| 189 |
+
):
|
| 190 |
+
continue
|
| 191 |
+
|
| 192 |
+
compare_op = compare_arg.target
|
| 193 |
+
maybe_symint_arg, compare_int = compare_arg.args
|
| 194 |
+
|
| 195 |
+
# x >= 0 will sometimes be canonicalized to -x <= 0, so in some
|
| 196 |
+
# cases the operation before the comparison is to multiply by -1. We
|
| 197 |
+
# can undo the canonicalization here
|
| 198 |
+
if (
|
| 199 |
+
maybe_symint_arg.op == "call_function" and
|
| 200 |
+
maybe_symint_arg.target == operator.mul and
|
| 201 |
+
maybe_symint_arg.args[0] == -1
|
| 202 |
+
):
|
| 203 |
+
maybe_symint_arg = maybe_symint_arg.args[1]
|
| 204 |
+
compare_op = operator.ge
|
| 205 |
+
compare_int = -1 * compare_int
|
| 206 |
+
|
| 207 |
+
if not (
|
| 208 |
+
"val" in maybe_symint_arg.meta and
|
| 209 |
+
isinstance(maybe_symint_arg.meta["val"], torch.SymInt)
|
| 210 |
+
):
|
| 211 |
+
continue
|
| 212 |
+
|
| 213 |
+
symint = maybe_symint_arg.meta["val"].node.expr
|
| 214 |
+
if not isinstance(symint, sympy.Symbol):
|
| 215 |
+
continue
|
| 216 |
+
|
| 217 |
+
if symint not in range_constraints:
|
| 218 |
+
raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")
|
| 219 |
+
|
| 220 |
+
found_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))
|
| 221 |
+
|
| 222 |
+
if compare_arg.target == operator.le:
|
| 223 |
+
existing_inline_assertions[symint] = ValueRanges(
|
| 224 |
+
lower=found_range.lower, upper=compare_int
|
| 225 |
+
)
|
| 226 |
+
elif compare_arg.target == operator.ge:
|
| 227 |
+
existing_inline_assertions[symint] = ValueRanges(
|
| 228 |
+
lower=compare_int, upper=found_range.upper
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return existing_inline_assertions
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Dict, Optional, Tuple, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
|
| 6 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 7 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 8 |
+
from torch._ops import OpOverload
|
| 9 |
+
|
| 10 |
+
aten = torch.ops.aten
|
| 11 |
+
|
| 12 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
|
| 13 |
+
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
|
| 14 |
+
aten._assert_async.msg: aten._functional_assert_async.msg,
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 19 |
+
"""
|
| 20 |
+
Functionalize ops with side effect in graph module by replacing the op with
|
| 21 |
+
functional version of it. A new dependency token (`dep_token`) will be
|
| 22 |
+
created and propagated through functional ops to output.
|
| 23 |
+
For example:
|
| 24 |
+
```
|
| 25 |
+
def f(x):
|
| 26 |
+
sym_constrain_range(x.shape[0], min=1, max=3)
|
| 27 |
+
return x.add(3)
|
| 28 |
+
```
|
| 29 |
+
Will be transformed to:
|
| 30 |
+
```
|
| 31 |
+
def f(x):
|
| 32 |
+
dep_token0 = _make_dep_token()
|
| 33 |
+
dep_token1 = _functional_sym_constrain_range(
|
| 34 |
+
x.shape[0], min=1, max=3, dep_token=dep_token0
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
return x.add(3), dep_token1
|
| 38 |
+
```
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self._dep_token: Optional[ProxyValue] = None
|
| 44 |
+
self._next_dep_token_index: Optional[int] = None
|
| 45 |
+
|
| 46 |
+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
| 47 |
+
# Early return if no non-functional assertions.
|
| 48 |
+
if not any(
|
| 49 |
+
n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
|
| 50 |
+
for n in graph_module.graph.nodes
|
| 51 |
+
):
|
| 52 |
+
return PassResult(graph_module=graph_module, modified=False)
|
| 53 |
+
|
| 54 |
+
gm = copy.deepcopy(graph_module)
|
| 55 |
+
self._dep_token = None
|
| 56 |
+
self._next_dep_token_index = None
|
| 57 |
+
return super().call(gm)
|
| 58 |
+
|
| 59 |
+
def call_operator(
|
| 60 |
+
self,
|
| 61 |
+
op: OpOverload,
|
| 62 |
+
args: Tuple[Argument, ...],
|
| 63 |
+
kwargs: Dict[str, Argument],
|
| 64 |
+
meta: NodeMetadata,
|
| 65 |
+
) -> ProxyValue:
|
| 66 |
+
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
|
| 67 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 68 |
+
|
| 69 |
+
if self._dep_token is None:
|
| 70 |
+
self._dep_token = super().call_operator(
|
| 71 |
+
aten._make_dep_token,
|
| 72 |
+
args=(),
|
| 73 |
+
kwargs={},
|
| 74 |
+
meta=self._create_dummy_node_metadata(),
|
| 75 |
+
)
|
| 76 |
+
self._dep_token.node.name = "dep_token0"
|
| 77 |
+
self._next_dep_token_index = 1
|
| 78 |
+
|
| 79 |
+
self._dep_token = super().call_operator(
|
| 80 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
|
| 81 |
+
args=args,
|
| 82 |
+
kwargs={**kwargs, "dep_token": self._dep_token},
|
| 83 |
+
meta=meta,
|
| 84 |
+
)
|
| 85 |
+
assert self._next_dep_token_index is not None
|
| 86 |
+
self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
|
| 87 |
+
self._next_dep_token_index += 1
|
| 88 |
+
|
| 89 |
+
return self._dep_token
|
| 90 |
+
|
| 91 |
+
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
| 92 |
+
assert self._dep_token is not None
|
| 93 |
+
|
| 94 |
+
return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
|
| 6 |
+
torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
|
| 7 |
+
torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int,
|
| 8 |
+
torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default,
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _replace_sym_size_ops_pass(gm: torch.fx.GraphModule):
|
| 13 |
+
for module in gm.modules():
|
| 14 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 15 |
+
continue
|
| 16 |
+
for node in module.graph.nodes:
|
| 17 |
+
if node.target in replacements:
|
| 18 |
+
node.target = replacements[node.target]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Set
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch._ops import OpOverload, OpOverloadPacket, HigherOrderOperator
|
| 5 |
+
from torch._export.error import InternalError
|
| 6 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
|
| 13 |
+
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
# TODO (tmanlaibaatar) remove this after https://github.com/pytorch/pytorch/pull/100749
|
| 17 |
+
_BLACK_LISTED_OPS: Set[OpOverloadPacket] = {
|
| 18 |
+
torch.ops.aten.sym_size,
|
| 19 |
+
torch.ops.aten.sym_stride,
|
| 20 |
+
torch.ops.aten.sym_numel,
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
|
| 24 |
+
if len(schema.arguments) == 0:
|
| 25 |
+
return False
|
| 26 |
+
alias_info = schema.arguments[0].alias_info
|
| 27 |
+
return (alias_info is not None) and (not alias_info.is_write)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
|
| 31 |
+
if is_view_op(schema) and schema.name.startswith("aten::"):
|
| 32 |
+
view_op_name = schema.name.split("::")[1]
|
| 33 |
+
view_op_overload = (
|
| 34 |
+
schema.overload_name
|
| 35 |
+
if schema.overload_name != ""
|
| 36 |
+
else "default"
|
| 37 |
+
)
|
| 38 |
+
view_copy_op_name = view_op_name + "_copy"
|
| 39 |
+
if not hasattr(torch.ops.aten, view_copy_op_name):
|
| 40 |
+
raise InternalError(f"{schema.name} is missing a view_copy variant")
|
| 41 |
+
|
| 42 |
+
view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name)
|
| 43 |
+
|
| 44 |
+
if not hasattr(view_copy_op_overload_packet, view_op_overload):
|
| 45 |
+
raise InternalError(f"{schema.name} is missing a view_copy variant")
|
| 46 |
+
|
| 47 |
+
return getattr(view_copy_op_overload_packet, view_op_overload)
|
| 48 |
+
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 53 |
+
"""
|
| 54 |
+
Our backend expects pure functional operators. For efficiency
|
| 55 |
+
purposes, we keep view ops around while functionalizing the exported
|
| 56 |
+
program. This pass replaces view ops with view copy ops for backends that
|
| 57 |
+
need AOT memory planning.
|
| 58 |
+
"""
|
| 59 |
+
def call_operator(self, op, args, kwargs, meta):
|
| 60 |
+
if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
|
| 61 |
+
return super().call_operator(
|
| 62 |
+
(_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if op in _BLACK_LISTED_OPS or isinstance(op, HigherOrderOperator):
|
| 66 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 67 |
+
|
| 68 |
+
if view_copy_op := get_view_copy_of_view_op(op._schema):
|
| 69 |
+
return super().call_operator(view_copy_op, args, kwargs, meta)
|
| 70 |
+
|
| 71 |
+
return super().call_operator(op, args, kwargs, meta)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc
ADDED
|
Binary file (5.63 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @generated by update_schema.py
|
| 2 |
+
# checksum<<4c9986f3aba283b1746995fff8fe7005b370c7e288adec65c03030349a4bab60>>
|
| 3 |
+
Argument:
|
| 4 |
+
kind: union
|
| 5 |
+
fields:
|
| 6 |
+
as_none:
|
| 7 |
+
type: Tuple[()]
|
| 8 |
+
as_tensor:
|
| 9 |
+
type: TensorArgument
|
| 10 |
+
as_tensors:
|
| 11 |
+
type: List[TensorArgument]
|
| 12 |
+
as_int:
|
| 13 |
+
type: int
|
| 14 |
+
as_ints:
|
| 15 |
+
type: List[int]
|
| 16 |
+
as_float:
|
| 17 |
+
type: float
|
| 18 |
+
as_floats:
|
| 19 |
+
type: List[float]
|
| 20 |
+
as_string:
|
| 21 |
+
type: str
|
| 22 |
+
as_strings:
|
| 23 |
+
type: List[str]
|
| 24 |
+
as_sym_int:
|
| 25 |
+
type: SymIntArgument
|
| 26 |
+
as_sym_ints:
|
| 27 |
+
type: List[SymIntArgument]
|
| 28 |
+
as_scalar_type:
|
| 29 |
+
type: ScalarType
|
| 30 |
+
as_memory_format:
|
| 31 |
+
type: MemoryFormat
|
| 32 |
+
as_layout:
|
| 33 |
+
type: Layout
|
| 34 |
+
as_device:
|
| 35 |
+
type: Device
|
| 36 |
+
as_bool:
|
| 37 |
+
type: bool
|
| 38 |
+
as_bools:
|
| 39 |
+
type: List[bool]
|
| 40 |
+
as_sym_bool:
|
| 41 |
+
type: SymBoolArgument
|
| 42 |
+
as_sym_bools:
|
| 43 |
+
type: List[SymBoolArgument]
|
| 44 |
+
as_graph:
|
| 45 |
+
type: GraphArgument
|
| 46 |
+
as_optional_tensors:
|
| 47 |
+
type: List[OptionalTensorArgument]
|
| 48 |
+
as_custom_obj:
|
| 49 |
+
type: CustomObjArgument
|
| 50 |
+
as_operator:
|
| 51 |
+
type: str
|
| 52 |
+
BufferMutationSpec:
|
| 53 |
+
kind: struct
|
| 54 |
+
fields:
|
| 55 |
+
arg:
|
| 56 |
+
type: TensorArgument
|
| 57 |
+
buffer_name:
|
| 58 |
+
type: str
|
| 59 |
+
CustomObjArgument:
|
| 60 |
+
kind: struct
|
| 61 |
+
fields:
|
| 62 |
+
name:
|
| 63 |
+
type: str
|
| 64 |
+
class_fqn:
|
| 65 |
+
type: str
|
| 66 |
+
Device:
|
| 67 |
+
kind: struct
|
| 68 |
+
fields:
|
| 69 |
+
type:
|
| 70 |
+
type: str
|
| 71 |
+
index:
|
| 72 |
+
type: Optional[int]
|
| 73 |
+
default: None
|
| 74 |
+
ExportedProgram:
|
| 75 |
+
kind: struct
|
| 76 |
+
fields:
|
| 77 |
+
graph_module:
|
| 78 |
+
type: GraphModule
|
| 79 |
+
opset_version:
|
| 80 |
+
type: Dict[str, int]
|
| 81 |
+
range_constraints:
|
| 82 |
+
type: Dict[str, RangeConstraint]
|
| 83 |
+
schema_version:
|
| 84 |
+
type: SchemaVersion
|
| 85 |
+
dialect:
|
| 86 |
+
type: str
|
| 87 |
+
GradientToParameterSpec:
|
| 88 |
+
kind: struct
|
| 89 |
+
fields:
|
| 90 |
+
arg:
|
| 91 |
+
type: TensorArgument
|
| 92 |
+
parameter_name:
|
| 93 |
+
type: str
|
| 94 |
+
GradientToUserInputSpec:
|
| 95 |
+
kind: struct
|
| 96 |
+
fields:
|
| 97 |
+
arg:
|
| 98 |
+
type: TensorArgument
|
| 99 |
+
user_input_name:
|
| 100 |
+
type: str
|
| 101 |
+
Graph:
|
| 102 |
+
kind: struct
|
| 103 |
+
fields:
|
| 104 |
+
inputs:
|
| 105 |
+
type: List[Argument]
|
| 106 |
+
outputs:
|
| 107 |
+
type: List[Argument]
|
| 108 |
+
nodes:
|
| 109 |
+
type: List[Node]
|
| 110 |
+
tensor_values:
|
| 111 |
+
type: Dict[str, TensorMeta]
|
| 112 |
+
sym_int_values:
|
| 113 |
+
type: Dict[str, SymInt]
|
| 114 |
+
sym_bool_values:
|
| 115 |
+
type: Dict[str, SymBool]
|
| 116 |
+
is_single_tensor_return:
|
| 117 |
+
type: bool
|
| 118 |
+
default: 'False'
|
| 119 |
+
custom_obj_values:
|
| 120 |
+
type: Dict[str, CustomObjArgument]
|
| 121 |
+
default: '{}'
|
| 122 |
+
GraphArgument:
|
| 123 |
+
kind: struct
|
| 124 |
+
fields:
|
| 125 |
+
name:
|
| 126 |
+
type: str
|
| 127 |
+
graph:
|
| 128 |
+
type: Graph
|
| 129 |
+
GraphModule:
|
| 130 |
+
kind: struct
|
| 131 |
+
fields:
|
| 132 |
+
graph:
|
| 133 |
+
type: Graph
|
| 134 |
+
signature:
|
| 135 |
+
type: GraphSignature
|
| 136 |
+
module_call_graph:
|
| 137 |
+
type: List[ModuleCallEntry]
|
| 138 |
+
GraphSignature:
|
| 139 |
+
kind: struct
|
| 140 |
+
fields:
|
| 141 |
+
input_specs:
|
| 142 |
+
type: List[InputSpec]
|
| 143 |
+
output_specs:
|
| 144 |
+
type: List[OutputSpec]
|
| 145 |
+
InputSpec:
|
| 146 |
+
kind: union
|
| 147 |
+
fields:
|
| 148 |
+
user_input:
|
| 149 |
+
type: UserInputSpec
|
| 150 |
+
parameter:
|
| 151 |
+
type: InputToParameterSpec
|
| 152 |
+
buffer:
|
| 153 |
+
type: InputToBufferSpec
|
| 154 |
+
tensor_constant:
|
| 155 |
+
type: InputToTensorConstantSpec
|
| 156 |
+
custom_obj:
|
| 157 |
+
type: InputToCustomObjSpec
|
| 158 |
+
InputToBufferSpec:
|
| 159 |
+
kind: struct
|
| 160 |
+
fields:
|
| 161 |
+
arg:
|
| 162 |
+
type: TensorArgument
|
| 163 |
+
buffer_name:
|
| 164 |
+
type: str
|
| 165 |
+
persistent:
|
| 166 |
+
type: bool
|
| 167 |
+
InputToCustomObjSpec:
|
| 168 |
+
kind: struct
|
| 169 |
+
fields:
|
| 170 |
+
arg:
|
| 171 |
+
type: CustomObjArgument
|
| 172 |
+
custom_obj_name:
|
| 173 |
+
type: str
|
| 174 |
+
InputToParameterSpec:
|
| 175 |
+
kind: struct
|
| 176 |
+
fields:
|
| 177 |
+
arg:
|
| 178 |
+
type: TensorArgument
|
| 179 |
+
parameter_name:
|
| 180 |
+
type: str
|
| 181 |
+
InputToTensorConstantSpec:
|
| 182 |
+
kind: struct
|
| 183 |
+
fields:
|
| 184 |
+
arg:
|
| 185 |
+
type: TensorArgument
|
| 186 |
+
tensor_constant_name:
|
| 187 |
+
type: str
|
| 188 |
+
Layout:
|
| 189 |
+
kind: enum
|
| 190 |
+
fields:
|
| 191 |
+
Unknown: 0
|
| 192 |
+
SparseCoo: 1
|
| 193 |
+
SparseCsr: 2
|
| 194 |
+
SparseCsc: 3
|
| 195 |
+
SparseBsr: 4
|
| 196 |
+
SparseBsc: 5
|
| 197 |
+
_mkldnn: 6
|
| 198 |
+
Strided: 7
|
| 199 |
+
LossOutputSpec:
|
| 200 |
+
kind: struct
|
| 201 |
+
fields:
|
| 202 |
+
arg:
|
| 203 |
+
type: TensorArgument
|
| 204 |
+
MemoryFormat:
|
| 205 |
+
kind: enum
|
| 206 |
+
fields:
|
| 207 |
+
Unknown: 0
|
| 208 |
+
ContiguousFormat: 1
|
| 209 |
+
ChannelsLast: 2
|
| 210 |
+
ChannelsLast3d: 3
|
| 211 |
+
PreserveFormat: 4
|
| 212 |
+
ModuleCallEntry:
|
| 213 |
+
kind: struct
|
| 214 |
+
fields:
|
| 215 |
+
fqn:
|
| 216 |
+
type: str
|
| 217 |
+
signature:
|
| 218 |
+
type: Optional[ModuleCallSignature]
|
| 219 |
+
default: None
|
| 220 |
+
ModuleCallSignature:
|
| 221 |
+
kind: struct
|
| 222 |
+
fields:
|
| 223 |
+
inputs:
|
| 224 |
+
type: List[Argument]
|
| 225 |
+
outputs:
|
| 226 |
+
type: List[Argument]
|
| 227 |
+
in_spec:
|
| 228 |
+
type: str
|
| 229 |
+
out_spec:
|
| 230 |
+
type: str
|
| 231 |
+
NamedArgument:
|
| 232 |
+
kind: struct
|
| 233 |
+
fields:
|
| 234 |
+
name:
|
| 235 |
+
type: str
|
| 236 |
+
arg:
|
| 237 |
+
type: Argument
|
| 238 |
+
Node:
|
| 239 |
+
kind: struct
|
| 240 |
+
fields:
|
| 241 |
+
target:
|
| 242 |
+
type: str
|
| 243 |
+
inputs:
|
| 244 |
+
type: List[NamedArgument]
|
| 245 |
+
outputs:
|
| 246 |
+
type: List[Argument]
|
| 247 |
+
metadata:
|
| 248 |
+
type: Dict[str, str]
|
| 249 |
+
OptionalTensorArgument:
|
| 250 |
+
kind: union
|
| 251 |
+
fields:
|
| 252 |
+
as_tensor:
|
| 253 |
+
type: str
|
| 254 |
+
as_none:
|
| 255 |
+
type: Tuple[()]
|
| 256 |
+
OutputSpec:
|
| 257 |
+
kind: union
|
| 258 |
+
fields:
|
| 259 |
+
user_output:
|
| 260 |
+
type: UserOutputSpec
|
| 261 |
+
loss_output:
|
| 262 |
+
type: LossOutputSpec
|
| 263 |
+
buffer_mutation:
|
| 264 |
+
type: BufferMutationSpec
|
| 265 |
+
gradient_to_parameter:
|
| 266 |
+
type: GradientToParameterSpec
|
| 267 |
+
gradient_to_user_input:
|
| 268 |
+
type: GradientToUserInputSpec
|
| 269 |
+
user_input_mutation:
|
| 270 |
+
type: UserInputMutationSpec
|
| 271 |
+
RangeConstraint:
|
| 272 |
+
kind: struct
|
| 273 |
+
fields:
|
| 274 |
+
min_val:
|
| 275 |
+
type: int
|
| 276 |
+
max_val:
|
| 277 |
+
type: int
|
| 278 |
+
ScalarType:
|
| 279 |
+
kind: enum
|
| 280 |
+
fields:
|
| 281 |
+
UNKNOWN: 0
|
| 282 |
+
BYTE: 1
|
| 283 |
+
CHAR: 2
|
| 284 |
+
SHORT: 3
|
| 285 |
+
INT: 4
|
| 286 |
+
LONG: 5
|
| 287 |
+
HALF: 6
|
| 288 |
+
FLOAT: 7
|
| 289 |
+
DOUBLE: 8
|
| 290 |
+
COMPLEXHALF: 9
|
| 291 |
+
COMPLEXFLOAT: 10
|
| 292 |
+
COMPLEXDOUBLE: 11
|
| 293 |
+
BOOL: 12
|
| 294 |
+
BFLOAT16: 13
|
| 295 |
+
SchemaVersion:
|
| 296 |
+
kind: struct
|
| 297 |
+
fields:
|
| 298 |
+
major:
|
| 299 |
+
type: int
|
| 300 |
+
minor:
|
| 301 |
+
type: int
|
| 302 |
+
SymBool:
|
| 303 |
+
kind: union
|
| 304 |
+
fields:
|
| 305 |
+
as_expr:
|
| 306 |
+
type: SymExpr
|
| 307 |
+
as_bool:
|
| 308 |
+
type: bool
|
| 309 |
+
SymBoolArgument:
|
| 310 |
+
kind: union
|
| 311 |
+
fields:
|
| 312 |
+
as_name:
|
| 313 |
+
type: str
|
| 314 |
+
as_bool:
|
| 315 |
+
type: bool
|
| 316 |
+
SymExpr:
|
| 317 |
+
kind: struct
|
| 318 |
+
fields:
|
| 319 |
+
expr_str:
|
| 320 |
+
type: str
|
| 321 |
+
hint:
|
| 322 |
+
type: Optional[SymExprHint]
|
| 323 |
+
default: None
|
| 324 |
+
SymExprHint:
|
| 325 |
+
kind: union
|
| 326 |
+
fields:
|
| 327 |
+
as_int:
|
| 328 |
+
type: int
|
| 329 |
+
as_float:
|
| 330 |
+
type: float
|
| 331 |
+
as_bool:
|
| 332 |
+
type: bool
|
| 333 |
+
SymInt:
|
| 334 |
+
kind: union
|
| 335 |
+
fields:
|
| 336 |
+
as_expr:
|
| 337 |
+
type: SymExpr
|
| 338 |
+
as_int:
|
| 339 |
+
type: int
|
| 340 |
+
SymIntArgument:
|
| 341 |
+
kind: union
|
| 342 |
+
fields:
|
| 343 |
+
as_name:
|
| 344 |
+
type: str
|
| 345 |
+
as_int:
|
| 346 |
+
type: int
|
| 347 |
+
TensorArgument:
|
| 348 |
+
kind: struct
|
| 349 |
+
fields:
|
| 350 |
+
name:
|
| 351 |
+
type: str
|
| 352 |
+
TensorMeta:
|
| 353 |
+
kind: struct
|
| 354 |
+
fields:
|
| 355 |
+
dtype:
|
| 356 |
+
type: ScalarType
|
| 357 |
+
sizes:
|
| 358 |
+
type: List[SymInt]
|
| 359 |
+
requires_grad:
|
| 360 |
+
type: bool
|
| 361 |
+
device:
|
| 362 |
+
type: Device
|
| 363 |
+
strides:
|
| 364 |
+
type: List[SymInt]
|
| 365 |
+
storage_offset:
|
| 366 |
+
type: SymInt
|
| 367 |
+
layout:
|
| 368 |
+
type: Layout
|
| 369 |
+
UserInputMutationSpec:
|
| 370 |
+
kind: struct
|
| 371 |
+
fields:
|
| 372 |
+
arg:
|
| 373 |
+
type: TensorArgument
|
| 374 |
+
user_input_name:
|
| 375 |
+
type: str
|
| 376 |
+
UserInputSpec:
|
| 377 |
+
kind: struct
|
| 378 |
+
fields:
|
| 379 |
+
arg:
|
| 380 |
+
type: Argument
|
| 381 |
+
UserOutputSpec:
|
| 382 |
+
kind: struct
|
| 383 |
+
fields:
|
| 384 |
+
arg:
|
| 385 |
+
type: Argument
|
| 386 |
+
SCHEMA_VERSION:
|
| 387 |
+
- 5
|
| 388 |
+
- 1
|
| 389 |
+
TREESPEC_VERSION: 1
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Tuple, Dict, Optional, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.export import export
|
| 7 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
| 8 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 9 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 10 |
+
from torch._subclasses import FakeTensor
|
| 11 |
+
from torch.fx.node import Target, Argument
|
| 12 |
+
from torch.library import Library
|
| 13 |
+
from torch.utils._pytree import tree_unflatten
|
| 14 |
+
import torch._export.exported_program as ep
|
| 15 |
+
import re
|
| 16 |
+
|
| 17 |
+
# Fragment library: lets us define additional (old-version) operator schemas
# on the existing "aten" namespace at runtime.
lib = Library("aten", "FRAGMENT")
# Implementation library: used to attach the upgrader functions as kernels
# for the schemas defined above.
impl_lib = Library("aten", "IMPL")

log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_target_version(versioned_upgrader_name: str) -> int:
    """Return the version an upgrader upgrades *to*.

    An upgrader named ``div_Scalar_0_3`` applies to ``div.Scalar`` versions
    0 through 3, so it upgrades the op to version 4 (the last number plus one).

    Raises:
        RuntimeError: if the name does not end with ``_<from>_<to>``.
    """
    if re.match(r"^.*_[0-9]+_[0-9]+$", versioned_upgrader_name) is None:
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    # The trailing number is the last version the upgrader is valid for.
    last_valid_version = versioned_upgrader_name.rsplit("_", 1)[-1]
    return int(last_valid_version) + 1
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_upgraders() -> Dict[str, Tuple[str, str]]:
|
| 33 |
+
"""Getting upgraders entry map and operator version map and merge them into one dict."""
|
| 34 |
+
upgraders = torch._C._get_upgraders_entry_map()
|
| 35 |
+
op_version_map = torch._C._get_operator_version_map()
|
| 36 |
+
output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
|
| 37 |
+
for opname, entry_list in op_version_map.items():
|
| 38 |
+
if not entry_list:
|
| 39 |
+
raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
|
| 40 |
+
entry = entry_list[0]
|
| 41 |
+
old_schema = entry.old_schema
|
| 42 |
+
upgrader_name = entry.upgrader_name
|
| 43 |
+
upgrader_str = upgraders.get(upgrader_name, None)
|
| 44 |
+
if not upgrader_str:
|
| 45 |
+
raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}")
|
| 46 |
+
output[upgrader_name] = (old_schema, upgrader_str)
|
| 47 |
+
return output
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GraphModuleOpUpgrader:
|
| 51 |
+
"""This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available.
|
| 52 |
+
To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In
|
| 53 |
+
__init__() it does the following:
|
| 54 |
+
1. parse the upgrader list and reorder for upgrading purpose.
|
| 55 |
+
2. register old versions of operators as custom ops.
|
| 56 |
+
3. prepare upgrader passes.
|
| 57 |
+
|
| 58 |
+
In `upgrade()` API run these upgrader passes.
|
| 59 |
+
|
| 60 |
+
An example of op_upgraders input:
|
| 61 |
+
{
|
| 62 |
+
"aten::div__Scalar_0_3": ( # versioned op name
|
| 63 |
+
"div._Scalar(self: Tensor, other: Scalar)", # old schema
|
| 64 |
+
'''
|
| 65 |
+
def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string
|
| 66 |
+
if (self.is_floating_point() or isinstance(other, float)):
|
| 67 |
+
return self.true_divide_(other)
|
| 68 |
+
return self.divide_(other, rounding_mode='trunc')
|
| 69 |
+
''',
|
| 70 |
+
),
|
| 71 |
+
},
|
| 72 |
+
|
| 73 |
+
Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
|
| 74 |
+
original TorchScript upgrader).
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 78 |
+
def __init__(self, old_target: Target, new_target: Target):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.old_target = old_target
|
| 81 |
+
self.new_target = new_target
|
| 82 |
+
|
| 83 |
+
def call_operator(
|
| 84 |
+
self,
|
| 85 |
+
op,
|
| 86 |
+
args: Tuple[Argument, ...],
|
| 87 |
+
kwargs: Dict[str, Argument],
|
| 88 |
+
meta: NodeMetadata,
|
| 89 |
+
) -> ProxyValue:
|
| 90 |
+
if op == self.old_target:
|
| 91 |
+
return super().call_operator(self.new_target, args, kwargs, meta)
|
| 92 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
compiler_opset_version: Optional[Dict[str, int]] = None,
|
| 97 |
+
model_opset_version: Optional[Dict[str, int]] = None,
|
| 98 |
+
op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
|
| 99 |
+
):
|
| 100 |
+
self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders
|
| 101 |
+
self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {}
|
| 102 |
+
self.model_opset_version = model_opset_version if model_opset_version else {}
|
| 103 |
+
self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes(
|
| 104 |
+
self._parse_upgraders(self.op_upgraders))
|
| 105 |
+
|
| 106 |
+
def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]:
|
| 107 |
+
"""Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well
|
| 108 |
+
as the upgrader function string literal."""
|
| 109 |
+
# TODO(larryliu0820): Add support for custom ops
|
| 110 |
+
op_namespace = "aten"
|
| 111 |
+
if not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version:
|
| 112 |
+
return []
|
| 113 |
+
model_ver = self.model_opset_version[op_namespace]
|
| 114 |
+
curr_ver = self.compiler_opset_version[op_namespace]
|
| 115 |
+
|
| 116 |
+
# key is the target version. div__Scalar_0_3 should have a key of 4.
|
| 117 |
+
versioned_upgraders: Dict[int, Tuple[str, str]] = {get_target_version(name): v for name, v in
|
| 118 |
+
op_upgraders.items()}
|
| 119 |
+
target_upgraders: List[Tuple[str, str]] = []
|
| 120 |
+
# we need all upgraders from model_ver + 1 to curr_ver, inclusively
|
| 121 |
+
for ver in range(model_ver + 1, curr_ver + 1):
|
| 122 |
+
if ver in versioned_upgraders:
|
| 123 |
+
target_upgraders.append(versioned_upgraders[ver])
|
| 124 |
+
else:
|
| 125 |
+
# we may be able to get away with missing upgraders, if that operator is missing from given graph
|
| 126 |
+
# module.
|
| 127 |
+
log.warning("Missing an upgrader to upgrade to version {ver}.", extra={"ver": ver})
|
| 128 |
+
|
| 129 |
+
return target_upgraders
|
| 130 |
+
|
| 131 |
+
@staticmethod
|
| 132 |
+
def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
|
| 133 |
+
"""Given a list of upgraders, loop through it from lower version to higher version and create passes for all
|
| 134 |
+
upgraders. se torch.Library API to register old ops. Op name will be
|
| 135 |
+
<name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:
|
| 136 |
+
|
| 137 |
+
lib = Library("aten", "FRAGMENT")
|
| 138 |
+
lib.define(old_schema)
|
| 139 |
+
|
| 140 |
+
impl_lib = Library("aten", "IMPL")
|
| 141 |
+
impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")
|
| 142 |
+
|
| 143 |
+
@:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
|
| 144 |
+
upgrader function literal text.
|
| 145 |
+
@:return upgrader passes, order matters
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
upgrader_passes = []
|
| 149 |
+
|
| 150 |
+
def register_old_op(name: str, schema: str, impl_str: str):
|
| 151 |
+
"""Registers an old version operator using impl_name as old op name."""
|
| 152 |
+
lib.define(schema)
|
| 153 |
+
try:
|
| 154 |
+
exec(impl_str)
|
| 155 |
+
except Exception as e:
|
| 156 |
+
raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
|
| 157 |
+
impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")
|
| 158 |
+
|
| 159 |
+
for (schema, upgrader_str) in upgraders:
|
| 160 |
+
upgrader_name = upgrader_str.split('(')[0].split(' ')[-1]
|
| 161 |
+
op_name = schema.split('(')[0].split("::")[-1]
|
| 162 |
+
schema = schema.replace(op_name, upgrader_name)
|
| 163 |
+
try:
|
| 164 |
+
register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str)
|
| 165 |
+
except RuntimeError as e:
|
| 166 |
+
if "with the same name and overload name multiple times" in str(e):
|
| 167 |
+
print(f"Registering {upgrader_name} multiple times")
|
| 168 |
+
else:
|
| 169 |
+
raise RuntimeError from e
|
| 170 |
+
old_op_target = getattr(torch.ops.aten, upgrader_name).default
|
| 171 |
+
# for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the
|
| 172 |
+
# "default" at the end.
|
| 173 |
+
op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2])
|
| 174 |
+
new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
|
| 175 |
+
# Note that the graph will have op names in the graph, but actually they are of old versions.
|
| 176 |
+
upgrader_passes.append(
|
| 177 |
+
GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target))
|
| 178 |
+
|
| 179 |
+
return upgrader_passes
|
| 180 |
+
|
| 181 |
+
def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram:
|
| 182 |
+
"""Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of
|
| 183 |
+
operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the
|
| 184 |
+
upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the
|
| 185 |
+
upgrader. After all passes are applied, the exported program will be upgraded to the target version."""
|
| 186 |
+
if not self.upgrader_passes:
|
| 187 |
+
return exported_program
|
| 188 |
+
|
| 189 |
+
args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"]
|
| 190 |
+
args_real_tensors = [torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg for
|
| 191 |
+
arg in args]
|
| 192 |
+
assert exported_program.call_spec.in_spec is not None
|
| 193 |
+
args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec)
|
| 194 |
+
assert kwargs == {}
|
| 195 |
+
|
| 196 |
+
for _pass in self.upgrader_passes:
|
| 197 |
+
upgraded_program = exported_program._transform_do_not_use(_pass)
|
| 198 |
+
# NB: we have to retrace the graph_module instead of ep because of some failure.
|
| 199 |
+
exported_program = export(upgraded_program.module(), args, kwargs)
|
| 200 |
+
|
| 201 |
+
return exported_program
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import warnings
|
| 3 |
+
from functools import wraps
|
| 4 |
+
from itertools import chain
|
| 5 |
+
|
| 6 |
+
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._prims_common as utils
|
| 10 |
+
from torch._prims_common import (
|
| 11 |
+
CustomOutParamAnnotation,
|
| 12 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 13 |
+
Number,
|
| 14 |
+
NumberType,
|
| 15 |
+
ShapeType,
|
| 16 |
+
TensorLike,
|
| 17 |
+
TensorLikeType,
|
| 18 |
+
)
|
| 19 |
+
from torch.utils import _pytree as pytree
|
| 20 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@overload
|
| 24 |
+
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@overload
|
| 29 |
+
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@overload
|
| 34 |
+
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@overload
|
| 39 |
+
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# TODO: implement ref.cast with an option to enforce safe casting
|
| 44 |
+
def _maybe_convert_to_dtype(a, dtype):
|
| 45 |
+
if isinstance(a, TensorLike):
|
| 46 |
+
if a.dtype != dtype:
|
| 47 |
+
return a.to(dtype)
|
| 48 |
+
return a
|
| 49 |
+
if isinstance(a, Number):
|
| 50 |
+
return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type]
|
| 51 |
+
if isinstance(a, Sequence):
|
| 52 |
+
return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
|
| 53 |
+
# Passthrough None because some functions wrapped with type promotion
|
| 54 |
+
# wrapper might have optional args
|
| 55 |
+
if a is None:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
|
| 62 |
+
if not isinstance(a, Number):
|
| 63 |
+
msg = f"Found unknown type {type(a)} when trying to convert scalars!"
|
| 64 |
+
raise ValueError(msg)
|
| 65 |
+
if not utils.is_weakly_lesser_type(type(a), typ):
|
| 66 |
+
msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
|
| 67 |
+
raise ValueError(msg)
|
| 68 |
+
|
| 69 |
+
return typ(a)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _annotation_has_type(*, typ, annotation):
|
| 73 |
+
if hasattr(annotation, "__args__"):
|
| 74 |
+
for a in annotation.__args__:
|
| 75 |
+
if _annotation_has_type(typ=typ, annotation=a):
|
| 76 |
+
return True
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
return typ is annotation
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    The return_dtype will be coerced to the wrapped function's dtype arg if it is available and
    not None.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        # Names of the wrapped function's parameters that take part in promotion.
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Collect only the promoting arguments the caller actually supplied.
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            # Arguments may be (nested) sequences of tensors/scalars; flatten
            # them before computing the compute/result promotion dtypes.
            flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            # Cast each participating argument to the computation dtype and
            # write it back into the bound-arguments mapping.
            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            # Override the return_dtype if a dtype arg is present and not None
            if "dtype" in bound.arguments:
                maybe_dtype = bound.arguments["dtype"]
                if maybe_dtype:  # dtype cannot be None
                    result_dtype = maybe_dtype

            # Cast the result (or each element of a sequence result) to the
            # promotion's result dtype; any other result type is a bug.
            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        # Preserve the wrapped function's signature so introspection still works.
        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Returns True if resize is necessary
|
| 156 |
+
def _resize_output_check(out: TensorLikeType, shape: ShapeType):
|
| 157 |
+
# If the shapes are correct there's nothing to do
|
| 158 |
+
if utils.same_shape(out.shape, shape):
|
| 159 |
+
return False
|
| 160 |
+
if out.numel() != 0:
|
| 161 |
+
msg = (
|
| 162 |
+
f"An output with one or more elements was resized since it had shape {str(out.shape)} "
|
| 163 |
+
"which does not match the required output shape {str(shape)}. "
|
| 164 |
+
"This behavior is deprecated, and in a future PyTorch release outputs will not "
|
| 165 |
+
"be resized unless they have zero elements. "
|
| 166 |
+
"You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
|
| 167 |
+
)
|
| 168 |
+
warnings.warn(msg)
|
| 169 |
+
return True
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    """Resize ``out`` in place to ``shape`` when required.

    Returns ``out`` either way: resized when the shapes disagree,
    untouched when they already match.
    """
    needs_resize = _resize_output_check(out, shape)
    if not needs_resize:
        return out
    return out.resize_(shape)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _safe_copy_out(
|
| 181 |
+
*, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
|
| 182 |
+
):
|
| 183 |
+
# Checks same device
|
| 184 |
+
if copy_from.device != copy_to.device:
|
| 185 |
+
msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format(
|
| 186 |
+
copy_from.device, copy_to.device
|
| 187 |
+
)
|
| 188 |
+
raise RuntimeError(msg)
|
| 189 |
+
|
| 190 |
+
# Checks safe cast
|
| 191 |
+
if exact_dtype:
|
| 192 |
+
torch._check(
|
| 193 |
+
copy_from.dtype == copy_to.dtype,
|
| 194 |
+
lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
|
| 195 |
+
f"but got {copy_to.dtype} instead",
|
| 196 |
+
)
|
| 197 |
+
else:
|
| 198 |
+
torch._check(
|
| 199 |
+
utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
|
| 200 |
+
lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
|
| 201 |
+
"but this can't be cast because it is not safe!",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return copy_to.copy_(copy_from)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
    """Decorator factory that adds a keyword-only ``out=`` parameter to a
    Python reference implementation, handling resize-and-copy into the
    provided out tensor(s) and rewriting the function's signature/annotations
    to advertise the new parameter.

    Args:
        out_names: names of the output parameters; defaults to a single "out".
        exact_dtype: when True, require the out tensor dtype to match exactly
            instead of merely being safely castable.
        pass_is_out: when True, forward ``is_out=<bool>`` to the wrapped fn.
    """
    # The wrapped function needs to convert the output parameters to ensure
    # compatibility between the Python API (which always uses "out" as the
    # parameter name and may be a tuple) and the Aten API (which may have
    # multiple output parameters and use different parameter names such as
    # "grad_input", "indices" or "values".)

    default_out_names = ("out",)
    if len(out_names) == 0:
        # Use default in out name
        out_names = default_out_names

    # A single out name means the op returns one tensor; multiple names mean
    # the result is a named tuple of tensors.
    is_tensor = len(out_names) == 1

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        # Functions taking both device and dtype are treated as factory
        # functions: an explicit `out` then supplies defaults for both kwargs.
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            if pass_is_out:
                # Tell the ref whether it is being invoked in out-variant form.
                result = fn(*args, is_out=(out is not None), **kwargs)
            else:
                result = fn(*args, **kwargs)
            # Sanity check: single-tensor result for is_tensor, otherwise a
            # tuple with one element per out name.
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                # assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    torch._check_type(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tuple
        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
            sig.empty,
            out_type,
        )
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type

        # In the special case of having a single tensor out parameter with a
        # name other than out, add a special annotation to name the parameter
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]

        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable e.g.
        # jit
        _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"  # type: ignore[attr-defined]

        return _fn

    return _out_wrapper
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _maybe_remove_out_wrapper(fn: Callable):
|
| 332 |
+
return inspect.unwrap(
|
| 333 |
+
fn,
|
| 334 |
+
stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def backwards_not_supported(prim):
    """Wrap ``prim`` so that its forward runs normally under autograd, but any
    attempt to backprop through the result raises a RuntimeError.
    """

    def redispatch_prim(args, kwargs):
        # Re-invoke the prim below the Autograd dispatch key so we do not
        # recurse back into this wrapper.
        with torch._C._AutoDispatchBelowAutograd():
            # NOTE(review): `old` is queried but never used — looks vestigial;
            # confirm before removing.
            old = torch._C._dispatch_tls_is_dispatch_key_excluded(
                torch._C.DispatchKey.ADInplaceOrView
            )
            return prim(*args, **kwargs)

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            # Rebuild the original (args, kwargs) from the flattened form
            # that autograd.Function.apply requires.
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        # Only route through the autograd.Function when gradients could
        # actually flow into this call.
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it; and custom
            # autograd function will incorrectly turn the result into
            # a view which will fail test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it! We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    original_signature = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        leading_scalar = bool(args) and isinstance(args[0], Number)
        if not leading_scalar:
            # No number in the leading position: defer to the tensor path.
            return fn(*args, **kwargs)
        # Promote the leading Python number to a tensor of the matching dtype,
        # run the op, then lower the result back to a Python number.
        scalar = args[0]
        tensor_arg = torch.tensor(scalar, dtype=utils.type_to_dtype(type(scalar)))
        result = fn(tensor_arg, *args[1:], **kwargs)
        assert isinstance(result, torch.Tensor)
        return result.item()

    _fn.__signature__ = original_signature  # type: ignore[attr-defined]
    return _fn
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py
ADDED
|
@@ -0,0 +1,1025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains tooling to compare weights and activations
|
| 3 |
+
across models. Example usage::
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
import torch
|
| 7 |
+
import torch.ao.quantization.quantize_fx as quantize_fx
|
| 8 |
+
import torch.ao.ns._numeric_suite_fx as ns
|
| 9 |
+
|
| 10 |
+
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
|
| 11 |
+
mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
|
| 12 |
+
# We convert a copy because we need the original prepared model
|
| 13 |
+
# to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
|
| 14 |
+
mq = quantize_fx.convert_fx(copy.deepcopy(mp))
|
| 15 |
+
|
| 16 |
+
#
|
| 17 |
+
# Comparing weights
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
# extract weight pairs
|
| 21 |
+
weight_comparison = ns.extract_weights('a', mp, 'b', mq)
|
| 22 |
+
|
| 23 |
+
# add SQNR for each comparison, inplace
|
| 24 |
+
ns.extend_logger_results_with_comparison(
|
| 25 |
+
weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 26 |
+
'sqnr')
|
| 27 |
+
|
| 28 |
+
# weight_comparison contains the weights from `mp` and `mq` stored
|
| 29 |
+
# in pairs, and can be used for further analysis.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
#
|
| 33 |
+
# Comparing activations, with error propagation
|
| 34 |
+
#
|
| 35 |
+
|
| 36 |
+
# add loggers
|
| 37 |
+
mp_ns, mq_ns = ns.add_loggers(
|
| 38 |
+
'a', copy.deepcopy(mp),
|
| 39 |
+
'b', copy.deepcopy(mq),
|
| 40 |
+
ns.OutputLogger)
|
| 41 |
+
|
| 42 |
+
# send an example datum to capture intermediate activations
|
| 43 |
+
datum = torch.randn(1, 1, 1, 1)
|
| 44 |
+
mp_ns(datum)
|
| 45 |
+
mq_ns(datum)
|
| 46 |
+
|
| 47 |
+
# extract intermediate activations
|
| 48 |
+
act_comparison = ns.extract_logger_info(
|
| 49 |
+
mp_ns, mq_ns, ns.OutputLogger, 'b')
|
| 50 |
+
|
| 51 |
+
# add SQNR for each comparison, inplace
|
| 52 |
+
ns.extend_logger_results_with_comparison(
|
| 53 |
+
act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 54 |
+
'sqnr')
|
| 55 |
+
|
| 56 |
+
# act_comparison contains the activations from `mp_ns` and `mq_ns` stored
|
| 57 |
+
# in pairs, and can be used for further analysis.
|
| 58 |
+
|
| 59 |
+
#
|
| 60 |
+
# Comparing activations, without error propagation
|
| 61 |
+
#
|
| 62 |
+
|
| 63 |
+
# create shadow model
|
| 64 |
+
mp_shadows_mq = ns.add_shadow_loggers(
|
| 65 |
+
'a', copy.deepcopy(mp),
|
| 66 |
+
'b', copy.deepcopy(mq),
|
| 67 |
+
ns.OutputLogger)
|
| 68 |
+
|
| 69 |
+
# send an example datum to capture intermediate activations
|
| 70 |
+
datum = torch.randn(1, 1, 1, 1)
|
| 71 |
+
mp_shadows_mq(datum)
|
| 72 |
+
|
| 73 |
+
# extract intermediate activations
|
| 74 |
+
shadow_act_comparison = ns.extract_shadow_logger_info(
|
| 75 |
+
mp_shadows_mq, ns.OutputLogger, 'b')
|
| 76 |
+
|
| 77 |
+
# add SQNR for each comparison, inplace
|
| 78 |
+
ns.extend_logger_results_with_comparison(
|
| 79 |
+
shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
|
| 80 |
+
'sqnr')
|
| 81 |
+
|
| 82 |
+
# shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored
|
| 83 |
+
# in pairs, and can be used for further analysis.
|
| 84 |
+
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
import collections
|
| 88 |
+
|
| 89 |
+
import torch
|
| 90 |
+
import torch.nn as nn
|
| 91 |
+
import torch.ao.quantization.quantize_fx as quantize_fx
|
| 92 |
+
from torch.fx import GraphModule
|
| 93 |
+
from torch.fx.graph import Node
|
| 94 |
+
from torch.ao.ns.fx.mappings import (
|
| 95 |
+
get_base_name_to_sets_of_related_ops,
|
| 96 |
+
)
|
| 97 |
+
from torch.ao.ns.fx.graph_matcher import (
|
| 98 |
+
get_matching_subgraph_pairs,
|
| 99 |
+
get_type_a_related_to_b,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
from .fx.weight_utils import (
|
| 103 |
+
extract_weight_from_node,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
from .fx.graph_passes import (
|
| 107 |
+
add_loggers_to_model,
|
| 108 |
+
create_a_shadows_b,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
from .fx.utils import (
|
| 112 |
+
rekey_logger_info_on_node_name_of_model,
|
| 113 |
+
maybe_add_missing_fqns,
|
| 114 |
+
get_target_type_str,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
from .fx.ns_types import (
|
| 118 |
+
NSSingleResultValuesType,
|
| 119 |
+
NSResultsType,
|
| 120 |
+
NSNodeTargetType,
|
| 121 |
+
)
|
| 122 |
+
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
|
| 123 |
+
from torch.ao.quantization.backend_config import BackendConfig
|
| 124 |
+
from torch.ao.quantization.fx.match_utils import _find_matches
|
| 125 |
+
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
|
| 126 |
+
from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
|
| 127 |
+
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
| 128 |
+
from torch.ao.quantization.qconfig import QConfigAny
|
| 129 |
+
from torch.ao.quantization import QConfigMapping
|
| 130 |
+
from torch.ao.ns.fx.n_shadows_utils import (
|
| 131 |
+
OutputProp,
|
| 132 |
+
_get_dedup_subgraphs,
|
| 133 |
+
SHADOW_WRAPPER_NODE_NAME_PREFIX,
|
| 134 |
+
group_results_by_subgraph,
|
| 135 |
+
create_results_comparison,
|
| 136 |
+
print_n_shadows_summary,
|
| 137 |
+
create_n_transformed_and_logged_copies_of_subgraph,
|
| 138 |
+
create_add_loggers_graph,
|
| 139 |
+
extract_weight_comparison,
|
| 140 |
+
)
|
| 141 |
+
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
|
| 142 |
+
|
| 143 |
+
from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
|
| 144 |
+
|
| 145 |
+
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
| 146 |
+
|
| 147 |
+
class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.

    Instances are inserted into a traced model; each forward call records the
    observed tensor (or RNN-style output) into ``stats``/``stats_rnn`` and
    passes the value through unchanged.
    """
    # Class-level annotations are required for TorchScript compatibility.
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = '',
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #  x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn
        # if loggers are added before prepare_fx, but we do not want
        # collect results of calibration, only results after convert_fx
        # so, we add a flag to control whether this logger collects data
        self.enabled = True
        # string representation of qconfig
        self.qconfig_str = qconfig_str
        # this can be turned off to reduce memory usage during calibration
        self.save_activations = True

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not self.enabled:
            return x
        if not self.save_activations:
            return x
        # TODO(future PR): consider refactoring this to better reuse the parent
        # class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            # Shape matches an RNN-style output: (output, (hidden, cell)).
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        # Show only the user-meaningful attributes, not nn.Module internals.
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != 'training') and not k.startswith('_')
        }
        return f"OutputLogger({clean_dict})"
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class OutputComparisonLogger(OutputLogger):
    """
    Same as OutputLogger, but also requires the original activation
    in order to calculate the comparison at calibration time
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        self.comparison_fn_name = 'sqnr'
        # precalculated comparisons of logger output versus reference
        self.comparisons = []

    def forward(self, x, x_ref):
        """
        """  # blank docblock to make autodoc happy
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported'
        if self.save_activations:
            # save the activation, for debugging
            self.stats.append(x.detach())
        # save the comparison
        # NOTE: the comparison is recorded even when save_activations is off,
        # since comparisons are cheap relative to stored activations.
        self.comparisons.append(self.comparison_fn(x, x_ref))
        return x

    def __repr__(self):
        # Show only the user-meaningful attributes, not nn.Module internals.
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != 'training') and not k.startswith('_')
        }
        return f"OutputComparisonLogger({clean_dict})"
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular FX quantization tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """
        """  # blank docblock to make autodoc happy
        # Observers and fake-quantize modules are opaque leaves here; any
        # other module defers to the stock quantization tracer's decision.
        opaque_types = (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
        )
        if isinstance(m, opaque_types):
            return True
        return super().is_leaf_module(m, module_qualified_name)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    """Extract weights for the given nodes of one model into ``results``.

    Mutates ``results`` in place, keyed first by reference name, then by
    result type, then by model name.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if not extracted_weight:
            # Node had no extractable weight; skip it.
            continue
        results.setdefault(ref_name, {res_type: {}})
        results[ref_name][res_type][model_name] = [extracted_weight]
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """Match subgraphs between two graph modules and extract paired weights.

    Finds matching subgraph pairs between ``gm_a`` and ``gm_b``, extracts the
    weight from each pair's base op node in both models, and returns the
    results rekeyed on the node names of model B.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """
    Extract weights from model A and model B, and return a comparison.

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    # trace both models with the numeric-suite tracer so observed modules
    # are preserved as leaves
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # carry over node_name_to_scope (when present) so logger FQNs can be
    # resolved later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    """Attach `logger_cls` loggers to one model at the given input/output nodes.

    Each tuple is (node, ref_name, ref_node_type); the ref_name ties logged
    data back to the matched-subgraph name shared between the two models.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    # reshape the tuple lists into node -> (ref_name, ref_node_type) lookups
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """Match subgraphs between two traced models and instrument both with loggers.

    Returns the two instrumented models as a tuple (model_a, model_b).
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to also log inputs of matched subgraphs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    # trace both models with the numeric-suite tracer
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # carry over node_name_to_scope (when present) so logger FQNs can be
    # resolved later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    """Traverse `model` and append the data captured by each logger to `results`.

    Modifies `results` in place; entries are nested as
    `results[ref_name][results_type][model_name] -> list of data dicts`,
    with each list kept sorted by (index_of_arg, index_within_arg).
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        # a scripted logger loses its Python class, so fall back to
        # matching on the original class name
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            # NOTE(review): this asserts against the results_type level of
            # the dict rather than the model_name level two lines below —
            # presumably a duplicate-extraction guard; confirm intent.
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            # RNN loggers record into stats_rnn; prefer it when non-empty
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            data = {
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'ref_node_target_type': mod.ref_node_target_type,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
                'qconfig_str': mod.qconfig_str,
            }
            # comparison loggers carry extra fields; default to empty otherwise
            if hasattr(mod, 'comparisons'):
                data['comparisons'] = mod.comparisons
                data['comparison_fn_name'] = mod.comparison_fn_name
            else:
                data['comparisons'] = []
                data['comparison_fn_name'] = ''
            results[key][mod.results_type][mod.model_name].append(data)
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# TODO(future PR): align on naming
|
| 584 |
+
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
|
| 585 |
+
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in `model_a` and `model_b`, and extract the logged
    information.

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """Build a single model in which `gm_a`'s matched subgraphs shadow `gm_b`'s,
    with `logger_cls` loggers attached for comparison."""
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        node_type_to_io_type_map: optional override of node type to io type mapping, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    # trace both models with the numeric-suite tracer
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # carry over node_name_to_scope (when present) so logger FQNs can be
    # resolved later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in a shadow model, and extract the logged
    information.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    # convert back to a plain dict so callers do not see defaultdict behavior
    return dict(results)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    Args:
        results: the result data structure from `extract_logger_info` or
          `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors
        comparison_name: string name under which to store the comparison
          in `model_name_2`'s results
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1 by matching on both arg indices
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)
|
| 767 |
+
|
| 768 |
+
def prepare_n_shadows_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_multi_mapping: QConfigMultiMapping,
    backend_config: BackendConfig,
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
    custom_tracer: Any = None,
) -> GraphModule:
    """
    Given a model with a graph with M ops such as


      args_kwargs_m -> op_m -> output_m


    And a set of N qconfigs for each op, creates a new model, with
    each of the subgraph of `op_m` transformed into

    .. code::

           |---------> op_m_n -> log_m_n
           |                     /
      args_kwargs_m ---------> op_m -> log_m_0

    Where op_m_n is op_m wrapped in a submodule and transformed with
    qconfig_n, and its inner graph looks like

    .. code::

      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
                  /
      kwargs_m ---

    This is useful for testing different quantization of multiple layers in
    a single pass through the model.

    High level TODOs for future PRs:
    * figure out a better way to name the output structure
    * return a results data structure instead of printing it out
    * add examples to docblocks
    """

    if custom_tracer is None:
        tracer = quantize_fx.QuantizationTracer([], [])
    else:
        tracer = custom_tracer
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph, modules, patterns, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)
    subgraphs_dedup: Dict[str, List[Node]] = \
        _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    # TODO(future PR): deduplicate repeating entries
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
    for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
        node_name_to_qconfig = _generate_node_name_to_qconfig(
            mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
        list_of_node_name_to_qconfig.append(node_name_to_qconfig)

    # For each region in the model, do the following:
    #   For each qconfig for that region, do the following:
    #     1. create a copy of the region wrapped in a module
    #     2. pass original args, original kwargs, and expected output to module
    #     3. add an output comparison logger and hook it up to compare
    #        actual output to expected output
    #     4. run `prepare_fx` on the module
    for (subgraph_idx, (match_name, nodes_in_this_subgraph)) in \
            enumerate(subgraphs_dedup.items()):
        create_n_transformed_and_logged_copies_of_subgraph(
            mt, subgraph_idx, match_name, nodes_in_this_subgraph,
            qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
            custom_prepare_fn, custom_prepare_kwargs  # type: ignore[arg-type]
        )

    return mt
|
| 863 |
+
|
| 864 |
+
# TODO(future PR): we should rethink the names of all the PNP APIs
|
| 865 |
+
def _prepare_n_shadows_add_loggers_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> torch.nn.Module:
    r"""
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    This creates a model which provides logging for the following
    problem: if we quantize `model` with `qconfig_mapping` and feed
    the same input through both models, log the comparisons of
    corresponding intermediate layers.

    The problem is solved with a single model.  Specifically, we
    partition `model` into N subgraphs, create a copy of each relevant
    subgraph, wrap it in a module, apply the quantization API to that
    module, and hook up loggers to measure the comparisons.

    Example starting graph:

      x0 -> op0 -> x1 -> op1 -> x2

    Example config: quantize op0 to int8, do nothing to op1.
    The following graph will be created:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog

    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
    and clog is a comparison logger.
    """

    tracer = quantize_fx.QuantizationTracer([], [])
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph, modules, patterns, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)
    subgraphs_dedup: Dict[str, List[Node]] = \
        _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)

    # Now, mutate the graph to be the add_loggers graph with propagation
    # error.
    create_add_loggers_graph(
        mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)

    return mt
|
| 939 |
+
|
| 940 |
+
# TODO(future PR): we should rethink the names of all the PNP APIs
|
| 941 |
+
def _n_shadows_compare_weights(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> NSResultsType:
    """
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    Prepares and converts `model` with a single qconfig mapping via the
    n-shadows flow, then returns the extracted weight comparison.
    """
    # wrap the single mapping so we can reuse the n-shadows machinery
    qconfig_multi_mapping = \
        QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping])
    mp = prepare_n_shadows_model(
        model, example_inputs, qconfig_multi_mapping, backend_config)
    # passing inputs through the model is necessary to populate
    # observers which observe weights with real values
    mp(*example_inputs)
    mq = convert_n_shadows_model(mp)
    weight_comparison = extract_weight_comparison(mq)
    return weight_comparison
|
| 962 |
+
|
| 963 |
+
# TODO(future PR): consider aligning API signature with other similar quantization
|
| 964 |
+
# functions (enable_fake_quant, etc)
|
| 965 |
+
def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
    """
    Turn every `OutputLogger` inside `model` on or off in place, by setting
    its `enabled` flag to the given value.
    """
    for module in model.modules():
        if isinstance(module, OutputLogger):
            module.enabled = enabled
|
| 972 |
+
|
| 973 |
+
# TODO(future PR): consider aligning API signature with other similar quantization
|
| 974 |
+
# functions (enable_fake_quant, etc)
|
| 975 |
+
def loggers_set_save_activations(
    model: torch.nn.Module,
    save_activations: bool,
) -> None:
    """
    Set the `save_activations` flag on every `OutputLogger` inside `model`,
    modifying the model in place.
    """
    for module in model.modules():
        if isinstance(module, OutputLogger):
            module.save_activations = save_activations
|
| 985 |
+
|
| 986 |
+
def convert_n_shadows_model(
    model: GraphModule,
    custom_convert_fn: Optional[Callable] = None,
    custom_convert_kwargs: Optional[Dict[str, Any]] = None
) -> GraphModule:
    """
    Given a model from `prepare_n_shadows_model`, runs `convert_fx`
    on each shadow submodule.

    If `custom_convert_fn` is provided it is used instead of `convert_fx`,
    called with `custom_convert_kwargs` (if any). Returns the same model,
    modified in place.
    """
    for node in model.graph.nodes:
        # TODO(future PR): consider matching in a safer way than
        # node name string match
        if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
            orig_mod = getattr(model, node.name)
            if custom_convert_fn is None:
                converted_mod = torch.ao.quantization.quantize_fx.convert_fx(
                    orig_mod)
            else:
                if custom_convert_kwargs is None:
                    custom_convert_kwargs = {}
                converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs)
            # replace the prepared shadow wrapper with its converted version
            setattr(model, node.name, converted_mod)

    return model
|
| 1010 |
+
|
| 1011 |
+
def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
    """
    Collect and return the data recorded by the `OutputLogger` instances
    in `model`.
    """
    collected: NSResultsType = {}
    _extract_logger_info_one_model(model, collected, OutputLogger)
    return collected
|
| 1018 |
+
|
| 1019 |
+
def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
    """
    Print a human-readable summary of extracted `results`: group the raw
    logger data by subgraph, build the comparison table, and print it.
    """
    grouped = group_results_by_subgraph(results)
    comparison = create_results_comparison(grouped)
    print_n_shadows_summary(comparison)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
toq = torch.ops.quantized
|
| 7 |
+
|
| 8 |
+
import torch.ao.nn.quantized as nnq
|
| 9 |
+
import torch.ao.nn.quantized.dynamic as nnqd
|
| 10 |
+
import torch.ao.nn.intrinsic.quantized as nniq
|
| 11 |
+
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
|
| 12 |
+
import torch.ao.nn.intrinsic.qat as nniqat
|
| 13 |
+
import torch.ao.nn.intrinsic as nni
|
| 14 |
+
import torch.ao.nn.qat as nnqat
|
| 15 |
+
import torch.ao.nn.qat.dynamic as nnqatd
|
| 16 |
+
from torch.ao.quantization.backend_config import get_native_backend_config
|
| 17 |
+
import torch.ao.quantization.fx._lower_to_native_backend as \
|
| 18 |
+
_lower_to_native_backend
|
| 19 |
+
import torch.ao.quantization.quantization_mappings as quantization_mappings
|
| 20 |
+
|
| 21 |
+
from .ns_types import NSNodeTargetType
|
| 22 |
+
|
| 23 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """
    Returns a mapping from an arbitrary base name (a stringified counter)
    to a set of related ops: the fp32, functional, fused, QAT and quantized
    flavors of the same logical operation.

    The mapping starts from hardcoded sets below and is then extended with
    connections derived from the native backend config and from the default
    lowering maps.
    """
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            'relu',
            'relu_',
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            'sigmoid',
            'sigmoid_',
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # functional transposed conv
        {
            F.conv_transpose1d,
        },
        {
            F.conv_transpose2d,
        },
        {
            F.conv_transpose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        },
        # F.hardsigmoid
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
        # pixel shuffle
        {
            nn.PixelShuffle,
        },
        {
            F.pixel_shuffle,
        },
        # pixel unshuffle
        {
            nn.PixelUnshuffle,
        },
        {
            F.pixel_unshuffle,
        },
        # narrow
        {
            torch.narrow,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():

        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from default lowering path
    #

    for source, (target1, target2) in \
            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps, ideally in the future this could be removed
    # after the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    # add the new connections from backend_config: merge each connection
    # into the first related set that already contains either endpoint
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break

    # assign each set a stringified-counter base name; enumerate replaces
    # the previous manual counter loop
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {
        str(counter): set_of_related_ops
        for counter, set_of_related_ops in enumerate(sets_of_related_ops)
    }
    return base_name_to_sets_of_related_ops
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """Return the base name whose related-op set contains `op`, or None."""
    return next(
        (
            base_name
            for base_name, related_ops in base_name_to_sets_of_related_ops.items()
            if op in related_ops
        ),
        None,
    )
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """
    Add `op` to the related-op set containing `related_op`.

    When `related_op` is None, `op` starts a brand new set under the first
    unused stringified-counter base name.  Raises AssertionError when
    `related_op` is given but found in no existing set.
    """
    if related_op is None:
        # find the first unused numeric base name and start a fresh set
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}
        return

    for related_set in base_name_to_sets_of_related_ops.values():
        if related_op in related_set:
            related_set.add(op)
            return
    # if we got here, related_op was not found
    raise AssertionError(f"{related_op} was not found")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# TODO(future PR): clean this up
|
| 488 |
+
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """
    Returns a mapping of category name -> set of ops, classifying
    functions ('funs_*'), modules ('mods_*') and tensor methods
    ('meths_*') by the numeric type of their inputs/outputs
    (fp32, fp16, int8, or either fp32 or int8).
    """
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        F.pixel_shuffle,
        F.pixel_unshuffle,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.narrow,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        # note: the original listed nnq.Dropout twice; a set keeps one
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.PixelShuffle,
        nn.PixelUnshuffle,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
    """
    Returns the functions, modules and tensor methods that the numeric
    suite matcher should never attempt to match across models.
    """
    return {
        'funs_unmatchable': {
            torch.quantize_per_tensor,
            operator.getitem,
        },
        'mods_unmatchable': {
            nn.Identity,
        },
        # shape/metadata/view-style tensor methods with no numerics to compare
        'meths_unmatchable': {
            'to',
            'dequantize',
            'reshape',
            'view',
            'unsqueeze_',
            'unsqueeze',
            'transpose',
            'squeeze_',
            'squeeze',
            'size',
            'shape',
            'resize_',
            'repeat_interleave',
            'repeat',
            'permute',
            'numel',
            'mean',
            'detach_',
            'detach',
            'contiguous',
            'clamp',
            'chunk',
        },
    }
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py
ADDED
|
@@ -0,0 +1,1311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.fx
|
| 3 |
+
from torch.fx import (
|
| 4 |
+
Node,
|
| 5 |
+
GraphModule,
|
| 6 |
+
Graph,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
from torch.ao.ns.fx.utils import (
|
| 10 |
+
# TODO(future PR): make this work correctly for methods
|
| 11 |
+
get_target_type_str,
|
| 12 |
+
get_normalized_nth_input,
|
| 13 |
+
)
|
| 14 |
+
from torch.ao.ns.fx.ns_types import (
|
| 15 |
+
NSSingleResultValuesType,
|
| 16 |
+
NSResultsType,
|
| 17 |
+
)
|
| 18 |
+
from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
|
| 19 |
+
from torch.ao.quantization import QConfigMapping
|
| 20 |
+
from torch.ao.quantization.qconfig import QConfigAny
|
| 21 |
+
from torch.ao.quantization.utils import getattr_from_fqn
|
| 22 |
+
from torch.ao.quantization.fx.match_utils import _MatchResult
|
| 23 |
+
from torch.utils._pytree import tree_map
|
| 24 |
+
|
| 25 |
+
import collections
|
| 26 |
+
import copy
|
| 27 |
+
from typing import List, Dict, Set, Tuple, Callable, Any, Optional
|
| 28 |
+
import operator
|
| 29 |
+
|
| 30 |
+
SHADOW_NODE_NAME_PREFIX = 'shadow'
|
| 31 |
+
SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper'
|
| 32 |
+
|
| 33 |
+
# TODO(future PR): reuse existing mapping instead of creating a new one
|
| 34 |
+
BINARY_FUNCTIONS = {
|
| 35 |
+
torch.add,
|
| 36 |
+
torch.Tensor.add,
|
| 37 |
+
operator.add,
|
| 38 |
+
torch.mul,
|
| 39 |
+
torch.Tensor.mul,
|
| 40 |
+
operator.mul,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
|
| 44 |
+
return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
|
| 45 |
+
|
| 46 |
+
def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
|
| 47 |
+
return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class OutputProp:
    """
    Output propagation (modeled from shape propagation).

    Given a GraphModule and an example input, saves the output flowing
    through each node on `node.traced_result`.

    Code based on the example from
    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
    """
    def __init__(self, mod):
        # `mod` must behave like a torch.fx.GraphModule: it needs a
        # `.graph` attribute and `named_modules()` support.
        self.mod = mod
        self.graph = mod.graph
        # cache of submodules keyed by qualified name, used for 'call_module'
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        """
        Interpret `self.graph` node by node, using `args` as the values for
        the graph placeholders, and attach each intermediate tensor result
        to its node as `node.traced_result`. Always returns None.
        """
        args_iter = iter(args)
        # environment: node name -> value computed for that node
        env : Dict[str, Node] = {}

        def load_arg(a):
            # replace every Node reference inside an arg structure with the
            # value previously computed for that node
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            # resolve a dotted attribute path (e.g. "sub.weight") on self.mod
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # first positional arg is the object the method is called on
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # NOTE(review): ops not handled above (e.g. 'output') fall through
            # with `result` still holding the previous iteration's value, so a
            # stale value is recorded for those nodes — presumably harmless
            # here since only tensor results are attached, but worth confirming.
            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
                node.traced_result = result

            env[node.name] = result

        return None
|
| 101 |
+
|
| 102 |
+
def _get_dedup_subgraphs(
    matches: Dict[str, _MatchResult]
) -> Dict[str, List[Node]]:
    """
    Convert `matches` (which contains one entry per matched node) into a
    mapping with one entry per matched subgraph, each value being the
    subgraph's nodes in chronological (execution) order.
    """
    # the original matches variable is unique by node, make it unique by subgraph
    # instead
    seen_nodes = set()
    subgraphs_dedup = {}

    # Dict items are not reversible until Python 3.8, so we hack it
    # to be compatible with previous Python versions
    # TODO(future PR): try reversed(list(matches.items()))
    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
    for name, cur_match in matches.items():
        matches_items_reversed.insert(0, (name, cur_match))

    # Note: the order is important. `matches` currently provides the matches
    # in reverse order. We would like to process the matches in non-reverse
    # order, so that we can create an intuitive naming scheme, such as
    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
        was_seen = False
        for node_or_tuple in cur_match[1]:

            # Cur_match[1] has an unusual type. It says that it's a `List[Node]`,
            # but it is really not. Furthermore, the contents of this field
            # can change from match results of multiple nodes of the same pattern
            #
            # For example, for conv -> bn -> relu, we see
            # match_results = {
            #   'conv': (relu, [(bn, conv), relu], ...),
            #   'bn': (relu, [(bn, conv), relu], ...),
            #   'relu': (relu, [(bn, conv), relu], ...),
            # }
            #
            # Ideally we should clean up the `find_matches` function to make
            # this more intuitive. For the purposes of this prototype, we hack
            # around it.

            if isinstance(node_or_tuple, Node):
                if node_or_tuple in seen_nodes:
                    was_seen = True
                seen_nodes.add(node_or_tuple)

            else:
                assert isinstance(node_or_tuple, tuple)
                for node in node_or_tuple:
                    assert isinstance(node, Node)
                    if node in seen_nodes:
                        was_seen = True
                    seen_nodes.add(node)

        # a repeated node means this entry describes a subgraph we have
        # already recorded under another name; skip the duplicate
        if was_seen:
            continue

        # Start with the unusual type, convert it to [op_0, ..., op_n]
        list_of_nodes = []

        if len(cur_match[1]) == 1:
            list_of_nodes = cur_match[1]
        else:
            assert len(cur_match[1]) == 2
            # either (a, b), or ((a, b), c) or (c, (a, b))
            # cannot make any assumptions on order, not clear what the
            # _find_matches function is doing to populate this
            # TODO(future PR): make this code less confusing, see discussion
            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
                # Reconstruct the chain order of three nodes by looking at
                # their producer (args[0]) / consumer (users) relationships.
                nodes = [node_a, node_b, node_c]
                first_node = None
                mid_node = None
                last_node = None
                for n in nodes:
                    prev_n = n.args[0]
                    next_n = next(iter(n.users))
                    if prev_n not in nodes:
                        first_node = n
                    elif next_n not in nodes:
                        last_node = n
                    else:
                        mid_node = n
                assert first_node is not None and mid_node is not None and \
                    last_node is not None
                assert mid_node.args[0] is first_node
                assert last_node.args[0] is mid_node
                # returned in reverse; caller reverses the list afterwards
                return [last_node, mid_node, first_node]

            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
                # (a, b)
                list_of_nodes = cur_match[1]
            elif isinstance(cur_match[1][0], tuple):
                # ((a, b), c)
                node_a, node_b = cur_match[1][0]
                node_c = cur_match[1][1]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)
            elif isinstance(cur_match[1][1], tuple):
                # (a, (b, c))
                node_a, node_b = cur_match[1][1]
                node_c = cur_match[1][0]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)

        # [node_n, ..., node_0], note that the order is reversed
        # to make it chronological for simple subgraphs
        list_of_nodes.reverse()
        subgraphs_dedup[name] = list_of_nodes

    return subgraphs_dedup
|
| 209 |
+
|
| 210 |
+
def _get_logger_for_subgraph(
    model: GraphModule,
    first_node: Node,
    last_node: Node,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    qconfig_str: str,
    logger_cls: Callable,
    fqn: Optional[str],
) -> torch.nn.Module:
    """
    Given a model and a linear subgraph starting from `first_node` and
    ending with `last_node`, creates a logger for the end of this
    subgraph.
    """
    logger = logger_cls(
        first_node.name,  # ref_node_name
        last_node.name,  # prev_node_name
        f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}',  # model_name
        'model',  # ref_name
        get_target_type_str(last_node, model),  # prev_node_target_type
        get_target_type_str(first_node, model),  # ref_node_target_type
        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
        0,  # index_within_arg
        0,  # index_of_arg
        fqn if fqn is not None else '',  # fqn
        qconfig_str,
    )
    # Usually we expect the user to add loggers, then calibrate, then convert,
    # and then populate loggers. This is why the loggers start disabled.
    # TODO(future PR): reconsider the design to make this more intuitive.
    logger.enabled = False
    return logger
|
| 245 |
+
|
| 246 |
+
def create_submodule_from_subgraph(
    model: torch.nn.Module,
    first_node: Node,
    last_node: Node,
) -> GraphModule:
    """
    Input: a model, and a linear subgraph within the model from first_node to
    last_node.

    Output: a new submodule containing a copy of the subgraph, with the inputs
    to the first node becoming the inputs to the submodule, and all other
    nodes in the subgraph being copied.

    Example inputs:

    `model`: a module with graph

      x0 -> op1 -> x1 -> op2 -> x2
             |
            arg1

    `first_node`: op1
    `last_node`: op2

    Example output: a new module with graph

      input1 -> op1_copy -> x1 -> op2_copy -> output1
                   |
                  arg1

    Raises:
        AssertionError: for unsupported graph shapes (binary ops as non-first
            nodes, nodes with multiple users, unhandled arg types, or more
            than `iteration_limit` nodes in the subgraph).
    """

    #
    # create a blank GraphModule with an empty graph
    #

    class M(torch.nn.Module):
        def forward(self, x):
            pass

    m = M()
    gm = torch.fx.symbolic_trace(m)
    g = gm.graph
    # erase the dummy placeholder and output so we start from an empty graph
    for node in reversed(gm.graph.nodes):
        g.erase_node(node)

    #
    # modify the graph to have a copy of our subgraph
    #

    cur_node_orig = first_node

    # counter used to generate unique attribute names for copied modules
    # and parameters attached to `gm`
    cur_name_idx = 0

    # safety valve so a malformed (e.g. cyclic) graph cannot loop forever
    iteration_limit = 100
    cur_iteration = 0

    while True:
        if cur_node_orig is first_node:
            # we are at the first node, we need to set up graph inputs
            # TODO(future): some graphs could have placeholders which are unrelated
            # to the first node, need to handle this
            cur_args_copy = []
            cur_kwargs_copy = {}
            seen_names: Set[str] = set()
            old_name_to_new_node: Dict[str, Node] = {}

            def _add_placeholder(
                g: Graph, node: Node, seen_names, old_name_to_new_node
            ):
                # note: for graphs starting with patterns such as `y = x + x`, we
                # need to ensure we do not add multiple placeholders with the
                # same name
                counter = 0
                while node.name + '_' + str(counter) in seen_names:
                    counter += 1
                cur_name = node.name + '_' + str(counter)
                seen_names.add(cur_name)
                placeholder = g.placeholder(cur_name)
                old_name_to_new_node[node.name] = placeholder
                return placeholder

            for arg in cur_node_orig.args:
                if isinstance(arg, Node):
                    p = _add_placeholder(
                        g, arg, seen_names, old_name_to_new_node)
                    cur_args_copy.append(p)
                elif isinstance(arg, (list, tuple)):
                    # replace Node elements with placeholders, keep the rest
                    new_arg = []
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_arg.append(_add_placeholder(
                                g, inner_arg, seen_names, old_name_to_new_node))
                        else:
                            new_arg.append(inner_arg)
                    cur_args_copy.append(new_arg)
                else:
                    cur_args_copy.append(arg)

            # TODO(future PR): handle non-normalized kwargs
            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
                if isinstance(kwarg, Node):
                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
                        g, kwarg, seen_names, old_name_to_new_node)
                elif isinstance(kwarg, (list, tuple)):
                    new_kwarg = []
                    for inner_kwarg in kwarg:
                        p = _add_placeholder(
                            g, inner_kwarg, seen_names, old_name_to_new_node)
                        new_kwarg.append(p)
                    cur_kwargs_copy[kwarg_name] = new_kwarg
                else:
                    cur_kwargs_copy[kwarg_name] = kwarg

            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
        else:
            # we are not at first node, first arg is from the previous node,
            # and all other args are copied

            # the current implementation is simplistic and cannot handle
            # ops with two or more arguments which need to be passed from
            # the previous op, so we assert them out
            assert cur_node_orig.target not in BINARY_FUNCTIONS

            # at this point in the code, cur_node_copy is pointing to the copy
            # of the previous node
            # TODO(future PR): this is not handling complicated graphs correctly, need to
            # look at actual relationships instead of assuming sequential graph
            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
            # for any fusion pattern which has them for a node that is not the
            # first node.
            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821

            if len(cur_node_orig.args) > 1:
                for arg in cur_node_orig.args[1:]:
                    if isinstance(arg, torch.nn.Parameter):
                        new_arg = arg.clone().detach()  # type: ignore[assignment]
                        mod_name = f"mod_{cur_name_idx}"
                        cur_name_idx += 1
                        setattr(gm, mod_name, new_arg)
                        # bugfix: placeholders are created on the Graph `g`,
                        # not on the GraphModule `gm` (GraphModule has no
                        # `placeholder` method, so the old `gm.placeholder`
                        # call raised AttributeError)
                        new_arg_placeholder = g.placeholder(mod_name)
                        cur_args_copy.append(new_arg_placeholder)
                    elif isinstance(arg, (float, int, torch.dtype)):
                        cur_args_copy.append(arg)
                    else:
                        raise AssertionError(f'arg of type {type(arg)} not handled yet')
            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]

        # copy the node
        # NOTE(review): for non-first nodes, `cur_kwargs_copy` still holds the
        # first node's kwargs copy; per the TODO above, kwargs for non-first
        # nodes are not really supported yet.
        if cur_node_orig.op == 'call_module':
            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
            orig_mod_copy = copy.deepcopy(orig_mod)
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        elif cur_node_orig.op == 'call_function':
            cur_node_copy = g.call_function(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        elif cur_node_orig.op == 'call_method':
            cur_node_copy = g.call_method(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        else:
            raise AssertionError(f'{cur_node_orig.op} not supported yet')

        if cur_node_orig is last_node:
            break

        # go to next node; only strictly linear subgraphs are supported
        assert len(cur_node_orig.users.keys()) == 1, \
            f'{cur_node_orig} has more than 1 users, not supported yet'
        cur_node_orig = next(iter(cur_node_orig.users.keys()))

        cur_iteration += 1
        if cur_iteration > iteration_limit:
            raise AssertionError('iteration limit exceeded')

    # set up outputs
    g.output(cur_node_copy)

    gm.recompile()
    return gm
| 434 |
+
|
| 435 |
+
def create_one_transformed_and_logged_copy_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    first_node: Node,
    last_node: Node,
    fqn: Optional[str],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    example_inputs: Any,
    last_added_shadow_node_list: List[Optional[Node]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
    subgraph candidate copy and instruments it with loggers.

    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
    add a logger to the end.

    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
    prepare it with `prepare_fx`.

    Mutates `mt` (new attributes + graph nodes) and writes the last inserted
    logger node into `last_added_shadow_node_list[0]`.
    """

    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger

    if subgraph_candidate_idx == 0:
        # idx = 0 is the floating point (original) version of the subgraph
        # We keep the subgraph as is, and add a logger at the end

        qconfig_str = ''
        logger_mod_orig = _get_logger_for_subgraph(
            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
            qconfig_str, OutputLogger, fqn)

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(last_node):
            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
        last_added_shadow_node_list[0] = new_node

    else:
        # idx > 0 means we have a candidate qconfig to try, so we need
        # to make a copy of the subgraph, feed it with the right inputs,
        # and add a logger at the end

        # get the qconfig
        # subtract one because the first candidate is the floating point
        # version of the subgraph
        node_name_to_qconfig = \
            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]

        # if no quantization is requested, skip
        # TODO(future PR): deduplicate equivalent qconfigs that come from
        # different qconfig mapping objects
        if qconfig is None:
            return

        qconfig_mapping = QConfigMapping().set_global(qconfig)

        # create a copy of the submodule, wrapped in a separate module
        orig_mod_copy_wrapped = create_submodule_from_subgraph(
            mt, first_node, last_node)

        # add a call to prepare_fx on the wrapper module
        if custom_prepare_fn is None:
            orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
                orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs)
        else:
            if custom_prepare_kwargs is None:
                custom_prepare_kwargs = {}
            # these kwargs are supplied by this function, so the caller
            # may not override them
            for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]:
                assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs"
            prepare_kwargs: Dict[str, Any] = {
                "example_inputs": example_inputs,
                "qconfig_mapping": qconfig_mapping
            }
            prepare_kwargs.update(custom_prepare_kwargs)
            orig_mod_copy_wrapped = custom_prepare_fn(
                orig_mod_copy_wrapped,
                **prepare_kwargs)

        # attach the wrapper to the model
        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, orig_mod_copy_wrapped)

        # add a call to the wrapper module from the parent graph
        insert_after_node = last_added_shadow_node_list[0]
        with mt.graph.inserting_after(insert_after_node):
            # TODO(future PR): handle fusion patterns where non-first nodes
            # need inputs

            # pass in all node args and kwargs

            new_args = []
            for arg in first_node.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node):
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_args.append(inner_arg)

            new_kwargs = {}
            for name, old_kwarg in first_node.kwargs.items():
                if isinstance(old_kwarg, Node):
                    new_kwargs[name] = old_kwarg
                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
                    # TODO(future PR): clarify why we are adding kwargs to args
                    new_args.extend(old_kwarg)

            new_args = tuple(new_args)  # type: ignore[assignment]

            new_node = mt.graph.call_module(
                attr_name, args=new_args, kwargs=new_kwargs)

        # add a logger to parent graph to observe the shadow wrapper
        logger_mod_orig = _get_logger_for_subgraph(
            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
            str(qconfig), OutputComparisonLogger, fqn)

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(new_node):
            # comparison logger sees both the shadow output and the fp32 output
            logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={})
        last_added_shadow_node_list[0] = logger

    mt.recompile()
|
| 568 |
+
|
| 569 |
+
def create_n_transformed_and_logged_copies_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    match_name: str,
    nodes_in_this_subgraph: List[Any],
    qconfig_mappings: List[QConfigMapping],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a model `mt` and a subgraph_idx, creates the needed copies
    of the subgraph for all qconfigs, and instruments them with loggers.

    Skips the subgraph (with a message) when it contains non-Node entries,
    when no example input can be recovered from `traced_result`, or when
    no candidate qconfig applies to it. Mutates `mt` otherwise.
    """
    # for now, assume that
    # 1. the first node has one input
    # 2. the last node has one output

    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
    # TODO(future PR): implement this
    if any(
        not isinstance(node, Node)
        for node in nodes_in_this_subgraph
    ):
        return

    first_node = nodes_in_this_subgraph[0]
    last_node = nodes_in_this_subgraph[-1]
    # We used output propagation to populate example values on each
    # node. Use the example values from the previous node as the input
    # to the current node.
    prev_node = get_normalized_nth_input(first_node, mt, 0)
    if isinstance(prev_node, list):
        example_inputs = [x.traced_result for x in prev_node]
    elif isinstance(prev_node, tuple):
        # bugfix: materialize a tuple instead of a generator expression.
        # `example_inputs` is consumed once per subgraph candidate in the
        # loop at the bottom of this function, and a generator would be
        # exhausted after the first candidate.
        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
    else:
        # currently some customer models do not have a traced_result in
        # every node, so we have to guard for this case since we cannot
        # quantize without an example input
        # TODO(future PR): add a test case for this once we have an easy
        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
        # for additional context
        if hasattr(prev_node, 'traced_result'):
            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
        else:
            print(
                'unable to get example input for node ' +
                f'{first_node.format_node()}, skipping')
            return

    # If there are no quantization configs for this subgraph, skip adding
    # loggers. This reduces memory usage for models where not all layers are
    # quantized.
    # TODO(future): consider making this configurable
    found_at_least_one_qconfig = False
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):

        if subgraph_candidate_idx == 0:
            # fp32 baseline does not need a qconfig
            continue

        # a. we have N shadows, so len(qconfig_mappings) is N
        # b. we will have the fp32 layer + N shadows, so overall number of
        #    (original_op) + (*shadows) will be N+1
        # c. since `subgraph_candidate_idx` represents (b), we need
        #    to subtract 1 to query from (a)
        node_name_to_qconfig = \
            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]
        if qconfig is not None:
            found_at_least_one_qconfig = True
            break
    if not found_at_least_one_qconfig:
        print('unable to find at least one qconfig for node ' +
              f'{first_node.format_node()}, skipping')
        return

    fqn = _maybe_get_fqn(first_node, mt)

    # We want the results to contain the subgraphs in natural order,
    # and the graph to also contain shadow wrappers and shadow loggers
    # in natural order.
    # If we just iterate in reverse, the graph will be in natural
    # order but the eventual results will be in reverse order.
    # So, we keep track of the last shadow logger we added and
    # always insert after it.
    last_added_shadow_node_list: List[Optional[Node]] = [None]
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):

        create_one_transformed_and_logged_copy_of_subgraph(
            mt, subgraph_idx, subgraph_candidate_idx, first_node,
            last_node, fqn, list_of_node_name_to_qconfig,
            example_inputs, last_added_shadow_node_list, custom_prepare_fn,
            custom_prepare_kwargs)
|
| 664 |
+
|
| 665 |
+
def create_add_loggers_graph(
    model: GraphModule,
    subgraphs_dedup: Dict[str, List[Node]],
    qconfig_mapping: QConfigMapping,
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    r"""
    Given a model, a model graph partition (currently a set of matched
    subgraphs) and instructions how to transform each subgraph
    (currently quantizing it according to qconfig_mapping), modifies
    the model graph to create an alternate path through the original graph,
    with each of the subgraphs quantized. This is useful to compare
    propagation error of a transformation such as quantization.

    For example, given layer op0 and op1, there are four cases when handling op1:
    1. op0 and op1 quantized
    2. op0 and op1 unquantized
    3. op0 quantized, op1 unquantized
    4. op0 unquantized, op1 quantized

    Example input, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \          \                 \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog      op1_1 -> x2_1 ----> clog

    Example output, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                            \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog

    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger

    def _get_subgraph_containing_node(node, subgraphs_dedup):
        # Linear scan over every matched subgraph; returns the subgraph
        # (a list of nodes) that contains `node`, or None if unmatched.
        for subgraph in subgraphs_dedup.values():
            if node in subgraph:
                return subgraph
        return None

    # Step 1: create shadow branches, going from
    #
    #   x0 -> op0 -> x1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog
    #
    # Later, the outputs of each shadow will be rerouted to calculate
    # propagation error.

    # Note: we cannot iterate over matched subgraphs because some nodes
    # may not be matched. So, we iterate over nodes in the graph, and
    # associate them to matched subgraphs if possible.

    nodes_to_skip = set()
    # for each subgraph, save a mapping from first node of subgraph
    # to first and last node of the shadow of this subgraph
    orig_first_node_to_shadow_in_node = {}
    orig_first_node_to_shadow_out_node = {}
    # need to record original list because we will mutate the graph as we go
    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
    cur_subgraph_idx = 0
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        insert_submodule_copy = False
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            # mark all nodes of the subgraph as handled so the outer loop
            # does not process them again
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
            qconfig = node_name_to_qconfig[first_node.name]
            if qconfig is not None:
                insert_submodule_copy = True
        else:
            # unmatched node: treat it as a single-node subgraph
            first_node, last_node = n, n

        if insert_submodule_copy:
            # quantized shadow: delegate to the existing helper which creates
            # a shadow wrapper submodule plus its loggers
            match_name = first_node.name
            create_n_transformed_and_logged_copies_of_subgraph(
                model, cur_subgraph_idx, match_name, maybe_subgraph,
                [qconfig_mapping], [node_name_to_qconfig],
                None, None  # type: ignore[arg-type]
            )
            # find the created shadow module and record it so we
            # can find it easily in step 2
            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
            new_shadow_mod = None
            for maybe_shadow_mod in model.graph.nodes:
                if maybe_shadow_mod.op == 'call_module' and \
                        maybe_shadow_mod.target == expected_shadow_target:
                    new_shadow_mod = maybe_shadow_mod
                    break
            assert new_shadow_mod is not None
            # a quantized shadow is a single submodule call, so it is both
            # the entry and the exit node of the shadow branch
            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

        else:
            # create a copy of the subgraph by only copying FX nodes
            # but not copying any parameters, to minimize memory usage
            subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \
                else [first_node]

            # add a regular logger after last_node
            qconfig_str = ''
            subgraph_candidate_idx = 0
            fqn = _maybe_get_fqn(first_node, model)
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
                qconfig_str, OutputLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            insertion_point = last_node
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(last_node,), kwargs={})
                insertion_point = logger

            # create a copy of the subgraph
            cur_node_orig = first_node
            cur_node_copy = None
            first_node_copy = None
            while cur_node_orig in subgraph_to_use:
                # TODO(future PR): make this support all possible args/kwargs
                if cur_node_orig is first_node:
                    # first node keeps its original inputs; rerouting to the
                    # shadow inputs happens in step 2
                    new_args = cur_node_orig.args
                    new_kwargs = cur_node_orig.kwargs
                else:
                    # subsequent nodes feed from the previously copied node
                    first_arg_for_copy = cur_node_copy
                    new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]])  # noqa: C409
                    new_kwargs = cur_node_orig.kwargs
                # make a copy of cur_node_orig
                with model.graph.inserting_after(insertion_point):
                    cur_node_copy = model.graph.create_node(
                        cur_node_orig.op,
                        cur_node_orig.target,
                        new_args,
                        new_kwargs,
                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
                    )
                    if first_node_copy is None:
                        first_node_copy = cur_node_copy
                # since now only linear subgraphs are supported, all nodes
                # except the last one must have only one user
                if cur_node_orig != last_node:
                    assert len(cur_node_orig.users.keys()) == 1
                # advance to the next original node; once we step past
                # last_node the loop condition terminates the copy
                cur_node_orig = next(iter(cur_node_orig.users.keys()))
                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
                insertion_point = cur_node_copy

            # add a comparison logger after last_node's copy
            subgraph_candidate_idx = 1
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
                qconfig_str, OutputComparisonLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(cur_node_copy, last_node), kwargs={})

            # save the final node so we can use it in step 2
            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy

        cur_subgraph_idx += 1

    model.recompile()

    # Step 2: reroute shadow inputs. We go from
    #
    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog        op1_1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
    #
    # sample values of key internal variables for the example above:
    #
    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
    #
    # note: for subgraphs with more than one node, in_node will be different
    # compared to out_node

    nodes_to_skip = set()
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
        else:
            first_node, last_node = n, n

        def maybe_remap_node_to_shadow(node):
            """
            If unshadowed `node` has a shadow version, return that. If not,
            return `node`.
            """
            if not isinstance(node, Node):
                # handle scalars
                return node

            if node.op in ('placeholder', 'get_attr'):
                return node

            # Find the shadowed version of this arg from the previous
            # subgraph. For this, we need to:
            # 1. navigate to the first node of the previous subgraph
            # 2. get the output of the shadow wrapper which has (1) as an input

            # For now, assume the arg is in matched subgraphs. In the
            # future we may have to handle the case where this is not true.
            prev_subgraph = _get_subgraph_containing_node(
                node, subgraphs_dedup)
            if prev_subgraph is None:
                prev_subgraph = [node]
            prev_first_node = prev_subgraph[0]
            prev_shadow_output = \
                orig_first_node_to_shadow_out_node[prev_first_node]
            return prev_shadow_output

        cur_shadow_input = \
            orig_first_node_to_shadow_in_node[first_node]
        assert cur_shadow_input is not None
        cur_shadow_input.args = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.args)
        cur_shadow_input.kwargs = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.kwargs)

    model.recompile()
|
| 918 |
+
|
| 919 |
+
def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
|
| 920 |
+
# input: shadow wrapper module
|
| 921 |
+
# output if shadow wrapper module has a weighted op:
|
| 922 |
+
# (quantize_fn, (quantize_fn_args))
|
| 923 |
+
# output if shadow wrapper module doesn't have a weighted op:
|
| 924 |
+
# None
|
| 925 |
+
|
| 926 |
+
# For now, assume that the weight is the second input
|
| 927 |
+
# to the shadow module. If that changes, we can fix it later.
|
| 928 |
+
placeholders_seen = 0
|
| 929 |
+
for shadow_n in shadow_wrapper.graph.nodes: # type: ignore[union-attr]
|
| 930 |
+
if shadow_n.op != 'placeholder':
|
| 931 |
+
continue
|
| 932 |
+
|
| 933 |
+
placeholders_seen += 1
|
| 934 |
+
if placeholders_seen != 2:
|
| 935 |
+
continue
|
| 936 |
+
|
| 937 |
+
# the subgraph looks like
|
| 938 |
+
#
|
| 939 |
+
# _input_scale_1 = self._input_scale_1
|
| 940 |
+
# _input_zero_point_1 = self._input_zero_point_1
|
| 941 |
+
# quantize_per_channel = torch.quantize_per_channel(
|
| 942 |
+
# w2_0, _input_scale_1, _input_zero_point_1,
|
| 943 |
+
# 0, torch.qint8)
|
| 944 |
+
#
|
| 945 |
+
# we have `w2_0`, and are navigating this subgraph
|
| 946 |
+
# to get `_input_scale_1` and `_input_zero_point_1`
|
| 947 |
+
|
| 948 |
+
assert len(shadow_n.users) == 1
|
| 949 |
+
quant_node = next(iter(shadow_n.users.keys()))
|
| 950 |
+
new_args: Any = None
|
| 951 |
+
if quant_node.target == torch.quantize_per_channel:
|
| 952 |
+
_weight, scale_node, zp_node, axis, dtype = quant_node.args
|
| 953 |
+
scale_val = getattr_from_fqn(
|
| 954 |
+
shadow_wrapper, scale_node.target)
|
| 955 |
+
zp_val = getattr_from_fqn(
|
| 956 |
+
shadow_wrapper, zp_node.target)
|
| 957 |
+
new_args = (scale_val, zp_val, axis, dtype)
|
| 958 |
+
else:
|
| 959 |
+
assert quant_node.target == torch.quantize_per_tensor
|
| 960 |
+
_weight, scale_node, zp_node, dtype = quant_node.args
|
| 961 |
+
scale_val = getattr_from_fqn(
|
| 962 |
+
shadow_wrapper, scale_node.target)
|
| 963 |
+
zp_val = getattr_from_fqn(
|
| 964 |
+
shadow_wrapper, zp_node.target)
|
| 965 |
+
new_args = (scale_val, zp_val, dtype)
|
| 966 |
+
return (quant_node.target, new_args)
|
| 967 |
+
|
| 968 |
+
return None
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def extract_weight_comparison(m: GraphModule) -> NSResultsType:
    """
    For each allowlisted weighted `call_function` node in `m` that has a
    corresponding shadow wrapper, quantize the fp32 weight using the
    scale/zero_point stored in the wrapper, compute the SQNR between the
    two weights, and return both entries in NSResultsType format.
    """

    # example graph:
    #
    #   w1 = self.w1
    #   b1 = self.b1
    #   linear = torch._C._nn.linear(x, w1, b1)
    #   shadow_0_0 = self.shadow_0_0(linear)
    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
    #
    # algorithm:
    # 1. for each call_function node matching our allowlist:
    # 2.   if corresponding shadow wrapper exists, extract the weight pair
    #
    # Note: this is not super robust, but that's ok because this is
    # just for legacy customers who depend on the previous two-model version
    # of this API. TBD if we need to make this robust.
    # Note: modules are not supported, since existing customers only
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = {
        torch.nn.functional.linear,
    }

    results: NSResultsType = {
        'model': {NSSingleResultValuesType.WEIGHT.value: {}}
    }

    for n in m.graph.nodes:  # type: ignore[union-attr]
        if not (n.op == 'call_function' and n.target in weighted_ops):
            continue

        # Check if we have a corresponding shadow wrapper
        # TODO(future PR, if needed): support kwargs
        # TODO(future PR, if needed): support multiple shadow users
        first_arg = n.args[0]
        shadow_wrapper_node = None
        for user in first_arg.users:
            # TODO(before land): fix string match
            if user.op == 'call_module' and \
                    user.target.startswith('shadow_wrapper'):
                shadow_wrapper_node = user
                break

        if shadow_wrapper_node is None:
            continue

        shadow_wrapper = getattr_from_fqn(
            m, shadow_wrapper_node.target)  # type: ignore[arg-type]
        weight_info = _get_weight_info_from_shadow_wrapper(
            shadow_wrapper)
        if weight_info is None:
            continue

        # get weight; assumes the weight is the second positional arg of
        # the weighted op (true for F.linear)
        w_node = n.args[1]
        w_obj = getattr_from_fqn(m, w_node.target).detach()

        # get a quantized version of weight, using the quantize function
        # and parameters recovered from the shadow wrapper
        quant_fn, quant_fn_args_except_first = weight_info
        new_args = (w_obj, *quant_fn_args_except_first)
        w_obj_q = quant_fn(*new_args)

        # add a comparison
        ref_node_name = n.name
        prev_node_name = n.name
        ref_node_type = get_target_type_str(n, m)
        prev_node_type = ref_node_type
        fqn = None
        if hasattr(m, '_node_name_to_scope'):
            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
        # fp32 and quantized entries share the same metadata and comparison;
        # they differ only in 'values'
        result_fp32 = {
            'res_type': NSSingleResultValuesType.WEIGHT.value,
            'values': [w_obj],
            'prev_node_name': prev_node_name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': ref_node_name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
            'qconfig_str': '',
            'comparisons': [comparison],
            'comparison_fn_name': 'sqnr',
        }
        result_q = {
            'res_type': NSSingleResultValuesType.WEIGHT.value,
            'values': [w_obj_q],
            'prev_node_name': prev_node_name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': ref_node_name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
            'qconfig_str': '',
            'comparisons': [comparison],
            'comparison_fn_name': 'sqnr',
        }

        # go from subgraph_n_1 to subgraph_n_0
        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_')
        name_fp32 = f"subgraph_{node_idx}_0"
        name_q = f"subgraph_{node_idx}_1"

        results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \
            [result_fp32]
        results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \
            [result_q]

    return results
|
| 1085 |
+
|
| 1086 |
+
# TODO(future PR): redesign this to make it easier to consume outputs
|
| 1087 |
+
def group_results_by_subgraph(results: NSResultsType) -> Any:
    """
    Regroup flat per-candidate results into a two-level mapping.

    Input (shortened):

        {'model': {'node_output': {
            'subgraph_0_0': [entry, ...],
            'subgraph_0_1': [entry, ...],
        }}}

    Output:

        {'subgraph_0': {'0': summary, '1': summary}, ...}

    where each ``summary`` keeps the ref node metadata, values, qconfig
    string and precomputed comparisons of the first result entry.
    """
    grouped: Any = {}

    # there is exactly one results type per run: node_output or weight
    results_type_key = next(iter(results['model'].keys()))

    for name_with_idx, candidate_results in results['model'][results_type_key].items():
        # convert `subgraph_m_n` into (`subgraph_m`, `n`)
        prefix, subgraph_idx, candidate_idx = name_with_idx.split('_')
        subgraph_name = f'{prefix}_{subgraph_idx}'

        first_entry = candidate_results[0]
        summary = {
            field: first_entry[field]
            for field in (
                'ref_node_name',
                'ref_node_target_type',
                'fqn',
                'values',
                'qconfig_str',
                'comparisons',
                'comparison_fn_name',
            )
        }

        grouped.setdefault(subgraph_name, {})[candidate_idx] = summary

    return grouped
|
| 1171 |
+
|
| 1172 |
+
# TODO(future PR): redesign this to make it easier to consume outputs
|
| 1173 |
+
def create_results_comparison(
    results_grouped,
) -> Any:
    """
    Condense grouped per-candidate results into per-subgraph comparison stats.

    Input: output of ``group_results_by_subgraph``:

        {'subgraph_0': {'0': {...baseline...}, '1': {...candidate...}}}

    Output:

        {'subgraph_0': {
            'ref_node_name': ..., 'ref_node_target_type': ..., 'fqn': ...,
            'candidates': {'1': {'qconfig_str': ..., 'comparison_fn_name': ...,
                                 'cmp_raw': tensor, 'cmp_mean': tensor}, ...},
        }}
    """
    comparison_by_subgraph = {}

    for subgraph_name, per_candidate in results_grouped.items():

        candidates = {}
        for candidate_idx, candidate_result in per_candidate.items():
            # candidate '0' is the unquantized baseline; comparing the
            # baseline to itself is meaningless, so it only contributes
            # metadata below
            if candidate_idx == '0':
                continue

            # the comparisons were precomputed during calibration, so
            # they are simply aggregated here
            stacked = torch.stack(candidate_result['comparisons'])
            candidates[candidate_idx] = {
                'qconfig_str': candidate_result['qconfig_str'],
                'comparison_fn_name': candidate_result['comparison_fn_name'],
                'cmp_raw': stacked,
                'cmp_mean': torch.mean(stacked),
            }

        baseline = per_candidate['0']
        comparison_by_subgraph[subgraph_name] = {
            'ref_node_name': baseline['ref_node_name'],
            'ref_node_target_type': baseline['ref_node_target_type'],
            'fqn': baseline['fqn'],
            'candidates': candidates,
        }

    return comparison_by_subgraph
|
| 1251 |
+
|
| 1252 |
+
# TODO(future PR): redesign this to make it easier to consume outputs
|
| 1253 |
+
def print_n_shadows_summary(
    results_comparison,
) -> None:
    """
    Print a tabular summary of ``create_results_comparison`` output.

    Input:

        {
          'subgraph_0': {
            'ref_node_name': 'linear1',
            'ref_node_target_type': '...',
            'fqn': '...',
            'candidates': {
              '1': {
                'qconfig_str': ...,
                'comparison_fn_name': ...,
                'cmp_raw': [45.0, 55.0],
                'cmp_mean': 50.0,
              },
              ...,
            },
          },
        }

    Prints:

    node_name | node_type | fqn | 0    | 1    | ...
    linear1   | ...       | ... | 45.0 | 50.0 | ...

    Requires the optional `tabulate` library; if it is not installed,
    prints an instruction message and returns without raising.
    """

    try:
        from tabulate import tabulate
    except ImportError:
        print("`print_tabular` relies on the library `tabulate`, "
              "which could not be found on this machine. Run `pip "
              "install tabulate` to install the library.")
        return

    results = []
    for subgraph_data in results_comparison.values():
        # one mean value per transformed candidate, in dict order
        mean_all_candidates = [
            candidate['cmp_mean']
            for candidate_name, candidate in subgraph_data['candidates'].items()
        ]

        data_row = [
            subgraph_data['ref_node_name'],
            subgraph_data['ref_node_target_type'],
            subgraph_data['fqn'],
            *mean_all_candidates,
        ]
        results.append(data_row)

    # Emit one numeric header per candidate column. Each row has exactly
    # three metadata columns (name, type, fqn) before the candidate means.
    # Bugfix: this previously used len(data_row[1]) -- the length of the
    # node *type string* -- which produced an unrelated number of headers.
    max_candidate_idx_len = -1
    for data_row in results:
        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]

    headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers]
    print(tabulate(results, headers=headers))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
from typing import NamedTuple
|
| 3 |
+
|
| 4 |
+
from torch.fx.graph import Node
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Any, List, Union, Callable
|
| 7 |
+
|
| 8 |
+
class NSSingleResultValuesType(str, enum.Enum):
    """The kind of value captured by a Numeric Suite result entry."""
    # weight of a weighted op
    WEIGHT = 'weight'
    # output activation of a node
    NODE_OUTPUT = 'node_output'
    # input activation of a node
    NODE_INPUT = 'node_input'
|
| 12 |
+
|
| 13 |
+
class NSSubgraph(NamedTuple):
    """A matched subgraph: its first node, last node, and the node of the
    base op used to relate it to subgraphs in other models."""
    start_node: Node
    end_node: Node
    base_op_node: Node
|
| 17 |
+
|
| 18 |
+
# TODO(future PR): see if we can use typing_extensions's TypedDict instead
# to properly type the various keys
# {
#   # one of NSSingleResultValuesType
#   'type': 'weight',
#   # the values of type specified above
#   'values': [torch.tensor(...), ...],
#   # name of the node directly before the logger
#   'prev_node_name': 'linear1',
#   # type of the underlying function or module
#   'prev_node_target_type': torch.nn.functional.linear  # or torch.nn.Linear, etc
#   # name of the node responsible for adding this logger
#   # Note: this may differ from prev_node_name if we are logging inputs
#   'ref_node_name': 'linear1',
#   # index of this node within the arg of the input/output node
#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
#   'index_within_arg': 0,
#   # index of this node within the args of the input/output node
#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
#   'index_of_arg': 0,
#   # precomputed comparisons of logger values to reference values
#   'comparisons': [torch.tensor(...), ...]
#   # name of function used for precomputed comparisons
#   'comparison_fn_name': 'sqnr',
#   # string representation of qconfig responsible for creating this logger
#   'qconfig_str': 'QConfig(...)',
# }
NSSingleResultType = Dict[str, Any]

# {
#   'layer_name_1': {  # subgraph name
#     'node_output': {  # results type (node_output, node_input, weight)
#       'model_name_a':  # model name
#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
#       'model_name_b':
#          [NSSingleResultType, ...],
#     },
#   },
# }
#
NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]

# Defines the underlying target type of a node, for example:
# `F.conv1d` for a `call_function` conv node
# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
NSNodeTargetType = Union[Callable, str]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
toq = torch.ops.quantized
|
| 5 |
+
|
| 6 |
+
from torch.fx import GraphModule
|
| 7 |
+
from torch.fx.graph import Node
|
| 8 |
+
|
| 9 |
+
from torch.ao.quantization.backend_config import get_native_backend_config
|
| 10 |
+
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
| 11 |
+
from torch.ao.quantization.utils import getattr_from_fqn
|
| 12 |
+
from .ns_types import NSNodeTargetType
|
| 13 |
+
from torch.ao.quantization import (
|
| 14 |
+
ObserverBase,
|
| 15 |
+
FakeQuantizeBase,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from typing import Dict, Tuple, Set, Callable, Any, Union, List
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    """
    Flatten each set of related ops into the set of all ordered pairs of
    its members (including self-pairs), i.e. the symmetric "is related to"
    relation between op types.
    """
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    related: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for op_set in base_name_to_sets_of_related_ops.values():
        members = list(op_set)
        # record each unordered pair in both directions so lookups are
        # symmetric
        for i, op_a in enumerate(members):
            for op_b in members[i:]:
                related.add((op_a, op_b))
                related.add((op_b, op_a))

    return related
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# One element of a fusion pattern, matched against a single FX node.
NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[str, Any],  # call_method name and first argument, example: ("to", torch.float16)
]

# A fusion pattern: a pair or a quadruple of elements, listed in reverse
# order (see get_reversed_fusions).
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
|
| 49 |
+
|
| 50 |
+
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Set of potential fusions, in reverse order.  The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents the
    second op, etc.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

    default_base_op_idx = 0
    for quant_pattern in all_quant_patterns.keys():
        # TODO: this is a temporary hack to flatten the patterns from quantization so
        # that it works with the ns matcher function, maybe we should use `_is_match`
        # in torch.ao.quantization.fx.match_utils to match the patterns
        if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
                isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
            # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
            # into (nn.ReLU, nn.BatchNorm2d, nn.Conv2d)
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])

        # Only patterns of multiple ops are fusions, ignore
        # patterns which contain a single ops (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # (in reversed order the observer/fake-quant class is prepended)
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]


    # After this point, results contains values such as
    # [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        # each manual pattern is also added with an observer/fake-quant
        # prepended (i.e. appended in regular order), mirroring the loop above
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.

    Walks backwards from `end_node` through each node's first positional
    arg, checking one fusion element per step.  `seen_nodes` holds nodes
    already claimed by a previous match; hitting one aborts the match.
    """
    cur_node = end_node
    for fusion_idx in range(len(reversed_fusion)):
        # each node can only belong to one matched pattern
        if cur_node in seen_nodes:
            return False

        cur_fusion_el = reversed_fusion[fusion_idx]

        if cur_node.op == 'call_function':
            # a function element is anything that is neither a method-name
            # string nor a module class
            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
                (not isinstance(cur_fusion_el, type))
            if fusion_el_is_fun:
                if cur_node.target != cur_fusion_el:
                    return False
                # step backwards to the producer of the first input
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_module':
            fusion_el_is_mod = isinstance(cur_fusion_el, type)
            if fusion_el_is_mod:
                assert isinstance(cur_node.target, str)
                # resolve the module instance so its class can be compared
                target_mod = getattr_from_fqn(gm, cur_node.target)
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                # step backwards to the producer of the first input
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_method':
            # method elements come in two shapes: "name" or ("name", first_arg),
            # e.g. "dequantize" or ("to", torch.float16)
            fusion_el_is_meth_with_second_arg = \
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False

                # step backwards to the producer of the first input
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            # placeholder / get_attr / output nodes can never match
            return False

    return True
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
from typing import Any, Callable, Dict, List, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.ao.quantization import QConfigMapping
|
| 8 |
+
from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
|
| 9 |
+
from torch.ao.quantization.qconfig import QConfigAny
|
| 10 |
+
|
| 11 |
+
__all__ = ["QConfigMultiMapping"]

# Maps each QConfigMapping attribute name (a "style" of qconfig matching)
# to the name of the QConfigMapping setter method used to populate it.
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
    "global_qconfig": "set_global",
    "object_type_qconfigs": "set_object_type",
    "module_name_regex_qconfigs": "set_module_name_regex",
    "module_name_qconfigs": "set_module_name",
    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
}
|
| 20 |
+
|
| 21 |
+
def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
|
| 22 |
+
to_remove = []
|
| 23 |
+
for index, cur_qconfig in enumerate(qconfig_list):
|
| 24 |
+
if cur_qconfig is None:
|
| 25 |
+
to_remove.append(index)
|
| 26 |
+
break
|
| 27 |
+
for checked_qconfig in qconfig_list[:index]:
|
| 28 |
+
if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
|
| 29 |
+
to_remove.append(index)
|
| 30 |
+
break
|
| 31 |
+
for index in to_remove[::-1]:
|
| 32 |
+
qconfig_list.pop(index)
|
| 33 |
+
|
| 34 |
+
class QConfigMultiMapping:
    """
    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
    so that multiple QConfigs can be specified for each QConfig matching style.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfigs

        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string

        ``set_module_name`` : sets the QConfigs for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
    single QConfig.

    Example usage::

        qconfig_mapping = QConfigMultiMapping()
            .set_global([qconfig1, qconfig2])
            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
            .set_object_type(torch.nn.ReLU, [qconfig1])
            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
            .set_module_name("module1", [None])
            .set_module_name("module2", [qconfig2])
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])

    """

    def __init__(self):
        # initialize this with 1 QConfigMapping to avoid corner cases
        # (e.g. inserting into an empty list)
        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

    def _handle_list_size_mismatch(
        self, qconfig_list: List[QConfigAny], style: str
    ) -> None:
        # this method handles cases where the size of qconfig_list does not match
        # the size of qconfig_mappings_list.
        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
        # qconfig C as an object_type_qconfig for conv ops. If we internally store
        # 1 QConfigMapping with A and C and another with just B, then the
        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.

        # we avoid this by maintaining the invariant that if any QConfigMapping
        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
        # have either a qconfig or None for that same style+key. In the above
        # example, a None qconfig would prevent the unwanted match in the
        # second QConfigMapping

        # NOTE(review): `style` is not read in this method — presumably kept
        # for signature symmetry with _insert_qconfig_list; confirm before
        # removing.

        if len(qconfig_list) > len(self.qconfig_mappings_list):
            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings

            # Add new QConfigMappings (initialized so we maintain the `invariant`)

            new_qconfig_mapping = QConfigMapping()
            # searches other QConfigMappings for qconfig style+keys
            # that need to be inserted as `None` into the new QConfigMapping
            for qconfig_mapping in self.qconfig_mappings_list:

                # global_qconfig has None by default
                for check_style in _QCONFIG_STYLE_ORDER[1:]:
                    qconfigs_dict = getattr(qconfig_mapping, check_style)
                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
                    for key in qconfigs_dict:
                        target_qconfigs_dict[key] = None
                # only the first QConfigMapping needs to be scanned: by the
                # invariant above, all mappings should carry the same
                # style+keys — TODO confirm
                break

            # insert copies of this new QConfigMapping until all entries
            # in qconfig_list can fit among the QConfigMappings
            while len(qconfig_list) > len(self.qconfig_mappings_list):
                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
        else:
            # Case: we have fewer qconfigs in qconfig_list than QConfigMappings

            # pad qconfig_list with `None` until length is same
            while len(qconfig_list) < len(self.qconfig_mappings_list):
                qconfig_list.append(None)

    # this function applies the insertion method across each QConfigMapping
    def _insert_qconfig_list(
        self,
        style: str,
        args: List[Union[str, int, Callable]],
        qconfig_list: List[QConfigAny],
    ) -> None:

        # we remove duplicates and None to make the ordering of qconfigs
        # deterministic upon insertion.
        _remove_duplicates_and_none(qconfig_list)

        self._handle_list_size_mismatch(qconfig_list, style)
        method_name = _QCONFIG_STYLE_TO_METHOD[style]
        # qconfig i goes into QConfigMapping i; after the size-mismatch
        # handling above, both lists have equal length
        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
            # uses QConfigMapping set method to insert qconfig
            set_method = getattr(qconfig_mapping, method_name)
            set_method(*args, qconfig)

    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
        """
        Set global QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
        """
        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set object type QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
        """
        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name_regex QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
        """
        self._insert_qconfig_list(
            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
        )
        return self

    def set_module_name(
        self, module_name: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
        """
        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
        return self

    def set_module_name_object_type_order(
        self,
        module_name: str,
        object_type: Callable,
        index: int,
        qconfig_list: List[QConfigAny],
    ) -> QConfigMultiMapping:
        """
        Set module_name_object_type_order QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
        """
        self._insert_qconfig_list(
            "module_name_object_type_order_qconfigs",
            [module_name, object_type, index],
            qconfig_list,
        )
        return self

    def __repr__(self):
        # one line per contained QConfigMapping, wrapped in brackets
        return (
            self.__class__.__name__ +
            " [" +
            "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) +
            "\n]"
        )

    @classmethod
    def from_list_qconfig_mapping(
        cls, qconfig_mapping_list: List[QConfigMapping]
    ) -> QConfigMultiMapping:
        """
        Creates a QConfigMultiMapping from a list of QConfigMappings
        """
        new_qconfig_multi_mapping = cls()

        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
            qconfig_mapping_list
        )

        # we need to avoid the issue described in _handle_list_size_mismatch,
        # so we reinsert all the qconfigs using the QConfigMultiMapping
        # set methods

        # go through all qconfig styles
        # note: global can be ignored since it is None by default
        for style in _QCONFIG_STYLE_ORDER[1:]:

            # gather all key+qconfigs for current style
            # into qconfig_dict_list
            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
            for qconfig_mapping in qconfig_mapping_list:
                qconfig_dict = getattr(qconfig_mapping, style)
                for key, qconfig in qconfig_dict.items():
                    if key not in qconfig_dict_list:
                        qconfig_dict_list[key] = []
                    qconfig_dict_list[key].append(qconfig)

            # reinsert all gathered key+qconfigs
            set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
            set_method = getattr(new_qconfig_multi_mapping, set_method_name)
            for key, qconfig_list in qconfig_dict_list.items():
                # tuple keys (e.g. module_name_object_type_order) unpack into
                # multiple setter args; scalar keys pass through as one arg
                if isinstance(key, tuple):
                    set_method(*key, qconfig_list)
                else:
                    set_method(key, qconfig_list)

        return new_qconfig_multi_mapping
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc
ADDED
|
Binary file (7.83 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc
ADDED
|
Binary file (75 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc
ADDED
|
Binary file (31.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (227 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (350 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx import GraphModule
|
| 3 |
+
from ..export_utils import _WrapperModule
|
| 4 |
+
from ..utils import (
|
| 5 |
+
get_aten_graph_module,
|
| 6 |
+
remove_tensor_overload_for_qdq_ops,
|
| 7 |
+
_replace_literals_with_new_placeholders,
|
| 8 |
+
_replace_literals_with_existing_placeholders,
|
| 9 |
+
)
|
| 10 |
+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
| 11 |
+
from torch.fx.subgraph_rewriter import replace_pattern
|
| 12 |
+
from torch._higher_order_ops.out_dtype import out_dtype
|
| 13 |
+
from typing import Optional, Callable, Tuple, Any
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"reference_representation_rewrite",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Example inputs for tracing the static quantized linear pattern; entries
# correspond positionally to the parameters of _qdq_quantized_linear /
# _reference_quantized_linear below.
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),  # x_i8
    torch.randn(1, dtype=torch.float),  # x_scale
    torch.zeros(1, dtype=torch.int),  # x_zero_point
    torch.tensor([-128], dtype=torch.int),  # x_quant_min
    torch.tensor([127], dtype=torch.int),  # x_quant_max
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),  # weight_i8
    torch.randn(1, dtype=torch.float),  # weight_scale
    torch.zeros(1, dtype=torch.int),  # weight_zero_point
    torch.tensor([-127], dtype=torch.int),  # weight_quant_min
    torch.tensor([127], dtype=torch.int),  # weight_quant_max
    torch.randn(1, dtype=torch.float),  # bias_fp32
    torch.randn(1, dtype=torch.float),  # out_scale
    torch.zeros(1, dtype=torch.int),  # out_zero_point
    torch.tensor([-128], dtype=torch.int),  # out_quant_min
    torch.tensor([127], dtype=torch.int),  # out_quant_max
)
|
| 40 |
+
|
| 41 |
+
def _qdq_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Quantized linear in dq -> fp32 linear -> q form, built from the
    decomposed quantized ops.

    NOTE(review): this appears to be the match pattern for subgraph
    rewriting (see the `replace_pattern` import above) — keep the op
    sequence exactly as-is so traced graphs continue to match.
    """
    # dequantize both int8 operands to fp32
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    # fp32 linear, then requantize the result back to int8
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8
|
| 55 |
+
|
| 56 |
+
def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Integer-arithmetic reference implementation of quantized linear;
    the replacement graph for the _qdq_quantized_linear pattern.

    NOTE(review): used for subgraph rewriting — keep the exact op
    sequence so the traced replacement graph stays stable.
    """
    # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args.
    # This results in failure to match the pattern.
    # Therefore, we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    # widen to int16 before subtracting zero points to avoid int8 overflow
    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation can work for the case
    # no matter if bias_scale == x_scale * weight_scale or not
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None)
    # TODO: change to mul.Scalar
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # rescale the int32 accumulator to the output quantization domain
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Example inputs for tracing the dynamic quantized linear pattern; entries
# correspond positionally to the parameters of _qdq_dynamic_quantized_linear /
# _reference_dynamic_quantized_linear below.
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randn((2, 5), dtype=torch.float),  # x_fp32
    -128,  # x_quant_min
    127,  # x_quant_max
    torch.finfo(torch.float32).eps,  # x_eps
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),  # weight_i8
    torch.randn(1, dtype=torch.float),  # weight_scale
    torch.zeros(1, dtype=torch.int),  # weight_zero_point
    torch.tensor([-127], dtype=torch.int),  # weight_quant_min
    torch.tensor([127], dtype=torch.int),  # weight_quant_max
    torch.randn(1, dtype=torch.float),  # bias_fp32
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _qdq_dynamic_quantized_linear(
    x_fp32, x_quant_min, x_quant_max, x_eps,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
):
    """Dynamically quantized linear in choose_qparams -> q -> dq -> fp32
    linear form, built from the decomposed quantized ops.

    NOTE(review): this appears to be the match pattern for subgraph
    rewriting (see the `replace_pattern` import above) — keep the op
    sequence exactly as-is so traced graphs continue to match.
    """
    # activation qparams are computed at runtime from the input range
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
    # quantize then immediately dequantize the activation (fake-quant round trip)
    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    return out_fp32
|
| 117 |
+
|
| 118 |
+
def _reference_dynamic_quantized_linear(
|
| 119 |
+
x_fp32, x_quant_min, x_quant_max, x_eps,
|
| 120 |
+
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
|
| 121 |
+
bias_fp32,
|
| 122 |
+
):
|
| 123 |
+
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
|
| 124 |
+
# decomposed representation for quantize_per_tensor
|
| 125 |
+
# TODO: use out_dtype(mul, ...) here when the op is ready
|
| 126 |
+
x_fp32 = x_fp32 / x_scale # fp32
|
| 127 |
+
# round modes might be different here
|
| 128 |
+
# pytorch is rounding to even, which is also common for most of the backends
|
| 129 |
+
x_fp32 = torch.round(x_fp32) # fp32
|
| 130 |
+
x_i32 = x_fp32.to(dtype=torch.int32) # int32
|
| 131 |
+
x_i32 = x_i32 + x_zero_point # int32
|
| 132 |
+
# clamp works for fp32, int32 and int8 dtypes
|
| 133 |
+
x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32
|
| 134 |
+
x_i8 = x_i32.to(dtype=torch.int8)
|
| 135 |
+
|
| 136 |
+
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
|
| 137 |
+
|
| 138 |
+
x_i16 = x_i8.to(torch.int16)
|
| 139 |
+
weight_i16 = weight_i8.to(torch.int16)
|
| 140 |
+
# always set bias to None so that the same representation can work for the case
|
| 141 |
+
# no matter if bias_scale == x_scale * weight_scale or not
|
| 142 |
+
acc_i32 = out_dtype(
|
| 143 |
+
torch.ops.aten.linear.default,
|
| 144 |
+
torch.int32,
|
| 145 |
+
x_i16 - x_zero_point,
|
| 146 |
+
weight_i16 - weight_zero_point,
|
| 147 |
+
None)
|
| 148 |
+
bias_scale = x_scale * weight_scale
|
| 149 |
+
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
|
| 150 |
+
acc_i32 = acc_i32 + bias_i32
|
| 151 |
+
out_fp32 = acc_i32 * (x_scale * weight_scale)
|
| 152 |
+
return out_fp32
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
| 156 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 157 |
+
torch.randn(1, dtype=torch.float),
|
| 158 |
+
torch.zeros(1, dtype=torch.int),
|
| 159 |
+
torch.tensor([-128], dtype=torch.int),
|
| 160 |
+
torch.tensor([127], dtype=torch.int),
|
| 161 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 162 |
+
torch.randn(1, dtype=torch.float),
|
| 163 |
+
torch.zeros(1, dtype=torch.int),
|
| 164 |
+
torch.tensor([-127], dtype=torch.int),
|
| 165 |
+
torch.tensor([127], dtype=torch.int),
|
| 166 |
+
torch.randn(1, dtype=torch.float),
|
| 167 |
+
torch.randn(1, dtype=torch.float),
|
| 168 |
+
torch.zeros(1, dtype=torch.int),
|
| 169 |
+
torch.tensor([-128], dtype=torch.int),
|
| 170 |
+
torch.tensor([127], dtype=torch.int),
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def _qdq_quantized_conv2d(
|
| 174 |
+
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
|
| 175 |
+
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
|
| 176 |
+
bias_fp32,
|
| 177 |
+
out_scale, out_zero_point, out_quant_min, out_quant_max
|
| 178 |
+
):
|
| 179 |
+
stride = [1, 1]
|
| 180 |
+
padding = [0, 0]
|
| 181 |
+
dilation = [1, 1]
|
| 182 |
+
transposed = False
|
| 183 |
+
output_padding = [0, 0]
|
| 184 |
+
groups = 1
|
| 185 |
+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
| 186 |
+
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
|
| 187 |
+
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
| 188 |
+
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
|
| 189 |
+
out_fp32 = torch.ops.aten.convolution.default(
|
| 190 |
+
x_fp32, weight_fp32, bias_fp32, stride, padding, dilation, transposed, output_padding, groups)
|
| 191 |
+
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
| 192 |
+
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
|
| 193 |
+
return out_i8
|
| 194 |
+
|
| 195 |
+
def _reference_quantized_conv2d(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Reference (integer-arithmetic) replacement for quantized conv2d.

    Same contract as `_qdq_quantized_conv2d`: int8 in, int8 out, with the
    accumulation performed in int32 via `out_dtype`.
    """
    # Fixed convolution hyper-parameters for this pattern.
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    # Widen to int16 so the zero-point subtraction cannot overflow int8.
    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation can work for the case
    # no matter if bias_scale == x_scale * weight_scale or not
    acc_i32 = out_dtype(
        torch.ops.aten.convolution.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None, stride, padding, dilation, transposed, output_padding, groups)
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to:
    # Take linear calculation for example
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32
    # Represent X, W fp32 as their dequant transforms
    # A_fp32 = (A_q - A_zero_point)/A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_fp32 - X_zp) * X_scale * (W_(i, k)_fp32 - W_zp) * W_scale] + bias_(i)_fp32
    # Factor out X_scale and W_scale
    # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32
    # In order to have the addition of bias_(i)_fp32 inside, we must do
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale  # noqa: B950
    # Note we had to multiply bias_fp32 with X_scale * W_scale = bias_scale
    # Thus bias quantization to int32 must be with X_scale * W_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    # Unsqueeze to match broadcast dims.
    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to literal matching nightmare
    # in graph pattern replacement
    bias_i32 = bias_i32.unsqueeze(-1)
    bias_i32 = bias_i32.unsqueeze(-1)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(
        torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
| 253 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 254 |
+
torch.randn(1, dtype=torch.float),
|
| 255 |
+
torch.zeros(1, dtype=torch.int),
|
| 256 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 257 |
+
torch.randn(1, dtype=torch.float),
|
| 258 |
+
torch.zeros(1, dtype=torch.int),
|
| 259 |
+
torch.randn(1, dtype=torch.float),
|
| 260 |
+
torch.zeros(1, dtype=torch.int),
|
| 261 |
+
torch.tensor([-128], dtype=torch.int),
|
| 262 |
+
torch.tensor([127], dtype=torch.int),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
def _qdq_quantized_add_relu(
|
| 266 |
+
x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
|
| 267 |
+
out_scale, out_zero_point, quant_min, quant_max
|
| 268 |
+
):
|
| 269 |
+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
|
| 270 |
+
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
|
| 271 |
+
out_fp32 = x_fp32 + y_fp32
|
| 272 |
+
out_fp32 = torch.ops.aten.relu(out_fp32)
|
| 273 |
+
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
| 274 |
+
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
|
| 275 |
+
)
|
| 276 |
+
return out_i8
|
| 277 |
+
|
| 278 |
+
def _reference_quantized_add_relu(
|
| 279 |
+
x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
|
| 280 |
+
out_scale, out_zero_point, quant_min, quant_max
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
See comments for `_reference_quantized_add` for more information on
|
| 284 |
+
how to derive the formula for out_i8 based on x_i8 and y_i8
|
| 285 |
+
"""
|
| 286 |
+
x_i32 = x_i8.to(torch.int32)
|
| 287 |
+
y_i32 = y_i8.to(torch.int32)
|
| 288 |
+
# TODO: change this to mul.Scalar?
|
| 289 |
+
x_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (x_i32 - x_zero_point), (x_scale / out_scale))
|
| 290 |
+
y_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (y_i32 - y_zero_point), (y_scale / out_scale))
|
| 291 |
+
out_i32 = x_i32 + y_i32 + out_zero_point
|
| 292 |
+
# out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
|
| 293 |
+
out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
|
| 294 |
+
return out_i8
|
| 295 |
+
|
| 296 |
+
def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point, quant_min, quant_max):
|
| 297 |
+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
|
| 298 |
+
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
|
| 299 |
+
out_fp32 = x_fp32 + y_fp32
|
| 300 |
+
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
| 301 |
+
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
|
| 302 |
+
)
|
| 303 |
+
return out_i8
|
| 304 |
+
|
| 305 |
+
def _reference_quantized_add(
|
| 306 |
+
x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
|
| 307 |
+
out_scale, out_zero_point, quant_min, quant_max
|
| 308 |
+
):
|
| 309 |
+
"""
|
| 310 |
+
# How to Derive the formula for out_i8 based on x_i8 and y_i8
|
| 311 |
+
# (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)
|
| 312 |
+
|
| 313 |
+
# out_i8 is quantized output, we can write down the formula for it first:
|
| 314 |
+
out_i8 = out_f32 / out_scale + out_zero_point (1)
|
| 315 |
+
|
| 316 |
+
# then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
|
| 317 |
+
out_f32 = x_f32 + y_f32 (2)
|
| 318 |
+
x_fp32 = (x_i8 - x_zero_point) * x_scale (3)
|
| 319 |
+
y_fp32 = (y_i8 - y_zero_point) * y_scale (4)
|
| 320 |
+
|
| 321 |
+
# applying the above fomula to the out_i8 equation we can get the following:
|
| 322 |
+
out_i8 = out_fp32 / out_scale + out_zero_point # (1)
|
| 323 |
+
= (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
|
| 324 |
+
= ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4)
|
| 325 |
+
"""
|
| 326 |
+
x_i32 = x_i8.to(torch.int32)
|
| 327 |
+
y_i32 = y_i8.to(torch.int32)
|
| 328 |
+
# TODO: use out_dtype op
|
| 329 |
+
x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
|
| 330 |
+
y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
|
| 331 |
+
out_i32 = x_i32 + y_i32 + out_zero_point
|
| 332 |
+
quant_min = -128
|
| 333 |
+
quant_max = 127
|
| 334 |
+
out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
|
| 335 |
+
return out_i8
|
| 336 |
+
|
| 337 |
+
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
| 338 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 339 |
+
torch.randn(1, dtype=torch.float),
|
| 340 |
+
torch.zeros(1, dtype=torch.int),
|
| 341 |
+
torch.tensor([-128], dtype=torch.int),
|
| 342 |
+
torch.tensor([127], dtype=torch.int),
|
| 343 |
+
torch.randn(1, dtype=torch.float),
|
| 344 |
+
torch.zeros(1, dtype=torch.int),
|
| 345 |
+
torch.tensor([-128], dtype=torch.int),
|
| 346 |
+
torch.tensor([127], dtype=torch.int),
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def _qdq_quantized_max_pool2d(
|
| 350 |
+
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
|
| 351 |
+
kernel_size = 1
|
| 352 |
+
stride = 1
|
| 353 |
+
padding = 0
|
| 354 |
+
dilation = 1
|
| 355 |
+
ceil_mode = False
|
| 356 |
+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
|
| 357 |
+
out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(x_fp32, kernel_size, stride, padding, dilation, ceil_mode)
|
| 358 |
+
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
| 359 |
+
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
|
| 360 |
+
return out_i8
|
| 361 |
+
|
| 362 |
+
def _reference_quantized_max_pool2d(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
    """Reference (integer) replacement for quantized max_pool2d.

    Pooling is order-preserving, so it can be applied directly to the
    zero-point-shifted integer values; the result is then rescaled to the
    output qparams.
    """
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
    x_i32 = x_i8.to(torch.int32)
    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_i32 - x_zero_point,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode
    )
    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
    out_i8 = out_fp32.to(torch.int8)
    return out_i8
|
| 384 |
+
|
| 385 |
+
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
| 386 |
+
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
| 387 |
+
torch.randn(1, dtype=torch.float),
|
| 388 |
+
torch.zeros(1, dtype=torch.int),
|
| 389 |
+
torch.tensor([-128], dtype=torch.int),
|
| 390 |
+
torch.tensor([127], dtype=torch.int),
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
| 394 |
+
x = torch.ops.quantized_decomposed.quantize_per_tensor(x_fp32, scale, zero_point, quant_min, quant_max, torch.int8)
|
| 395 |
+
return x
|
| 396 |
+
|
| 397 |
+
def _reference_quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
| 398 |
+
# TODO: use out_dtype(mul, ...) here when the op is ready
|
| 399 |
+
x = x_fp32 / scale # fp32
|
| 400 |
+
# round modes might be different here
|
| 401 |
+
# pytorch is rounding to even, which is also common for most of the backends
|
| 402 |
+
x = torch.round(x) # fp32
|
| 403 |
+
x = x.to(dtype=torch.int32) # int32
|
| 404 |
+
x = x + zero_point # int32
|
| 405 |
+
# clamp works for fp32, int32 and int8 dtypes
|
| 406 |
+
x = torch.clamp(x, quant_min, quant_max) # int32
|
| 407 |
+
x = x.to(dtype=torch.int8)
|
| 408 |
+
return x
|
| 409 |
+
|
| 410 |
+
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
| 411 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 412 |
+
torch.randn(1, dtype=torch.float),
|
| 413 |
+
torch.zeros(1, dtype=torch.int),
|
| 414 |
+
torch.tensor([-128], dtype=torch.int),
|
| 415 |
+
torch.tensor([127], dtype=torch.int),
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
| 419 |
+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max, torch.int8)
|
| 420 |
+
return x_fp32
|
| 421 |
+
|
| 422 |
+
def _reference_dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
| 423 |
+
# without using quant_min/max in clamp, the traced graph will not have quant_mi/max args.
|
| 424 |
+
# This results in failure to match the pattern.
|
| 425 |
+
# Therefore, we call a torch.ops.aten.clamp here
|
| 426 |
+
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
|
| 427 |
+
# TODO: use out_dtype op
|
| 428 |
+
# note: x_i8.to(torch.int32) does not work here
|
| 429 |
+
# TODO: debug the implementation later when torchdynamo time out issue is resolved
|
| 430 |
+
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
|
| 431 |
+
|
| 432 |
+
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
| 433 |
+
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
| 434 |
+
torch.randn(3, dtype=torch.float),
|
| 435 |
+
torch.zeros(3, dtype=torch.int),
|
| 436 |
+
1,
|
| 437 |
+
-128,
|
| 438 |
+
127,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def _quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
|
| 442 |
+
out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
|
| 443 |
+
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
|
| 444 |
+
)
|
| 445 |
+
return out_i8
|
| 446 |
+
|
| 447 |
+
def _reference_quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
|
| 448 |
+
x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
|
| 449 |
+
out_i32 = torch.ops.aten.clamp(torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max)
|
| 450 |
+
out_i32 = torch.transpose(out_i32, ch_axis, -1)
|
| 451 |
+
return out_i32.to(torch.int8)
|
| 452 |
+
|
| 453 |
+
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
| 454 |
+
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
| 455 |
+
torch.randn(3, dtype=torch.float),
|
| 456 |
+
torch.zeros(3, dtype=torch.int),
|
| 457 |
+
1,
|
| 458 |
+
-128,
|
| 459 |
+
127,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
def _dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
|
| 463 |
+
# the following will be replaced as placeholders
|
| 464 |
+
out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
|
| 465 |
+
x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
|
| 466 |
+
)
|
| 467 |
+
return out_fp32
|
| 468 |
+
|
| 469 |
+
def _reference_dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
|
| 470 |
+
# the following will be replaced as placeholders
|
| 471 |
+
# in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops)
|
| 472 |
+
# we call a torch.ops.aten.clamp here
|
| 473 |
+
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
|
| 474 |
+
x_i8 = torch.transpose(x_i8, ch_axis, -1)
|
| 475 |
+
x_i32 = x_i8.to(torch.int32)
|
| 476 |
+
out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
|
| 477 |
+
out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
|
| 478 |
+
return out_fp32
|
| 479 |
+
|
| 480 |
+
def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
    """Post-export transform for the per-channel q/dq patterns: rebind the
    literals (ch_axis=1, quant_min=-128, quant_max=127) to the existing
    placeholders at indices 3/4/5, excluding the -1 used by transpose."""
    return _replace_literals_with_existing_placeholders(
        gm,
        exclude_literals=[-1],
        literal_to_ph_idx={1: 3, -128: 4, 127: 5},
    )
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@dataclass
|
| 489 |
+
class _RewriteInfo:
|
| 490 |
+
"""Data needed for rewrite, this includes example inputs, pattern and replacement functions
|
| 491 |
+
and post transformation functions for the exported pattern and replacement GraphModule
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
# example inputs used for exporting the pattern into GraphModule
|
| 495 |
+
example_inputs: Tuple[Any, ...]
|
| 496 |
+
pattern: Callable
|
| 497 |
+
replacement: Callable
|
| 498 |
+
# post transformation on the exported pattern and replacement GraphModule
|
| 499 |
+
pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
| 500 |
+
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
| 501 |
+
|
| 502 |
+
# Registry of all pattern -> reference-representation rewrites, consumed in
# order by `reference_representation_rewrite`.
_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_dynamic_quantized_linear),
        _WrapperModule(_reference_dynamic_quantized_linear),
        # Bind the shared literals (-128/127/eps) back to placeholders 1/2/3
        # on both the pattern and the replacement.
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={
                -128: 1,
                127: 2,
                torch.finfo(torch.float32).eps: 3,
            },
        ),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={
                -128: 1,
                127: 2,
                torch.finfo(torch.float32).eps: 3,
            },
        ),
    ),
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_linear),
        _WrapperModule(_reference_quantized_linear),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_conv2d),
        _WrapperModule(_reference_quantized_conv2d),
        # -1 is excluded: it is the unsqueeze dim, not a quantization literal.
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add_relu),
        _WrapperModule(_reference_quantized_add_relu),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add),
        _WrapperModule(_reference_quantized_add),
    ),
    _RewriteInfo(
        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_max_pool2d),
        _WrapperModule(_reference_quantized_max_pool2d),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_tensor_int8),
        _WrapperModule(_reference_quantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_tensor_int8),
        _WrapperModule(_reference_dequantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_channel_int8),
        _WrapperModule(_reference_quantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_channel_int8),
        _WrapperModule(_reference_dequantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
]
|
| 580 |
+
|
| 581 |
+
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
    """Rewrite q/dq-style quantized patterns in ``model`` into their
    reference-representation (decomposed integer arithmetic) form.

    Iterates over ``_REWRITE_INFO_LIST``: each pattern/replacement callable is
    exported to an aten GraphModule using its example inputs, tensor overloads
    on q/dq ops are stripped, optional post-transforms are applied, and the
    subgraph substitution is performed with ``replace_pattern``.

    Args:
        model: GraphModule to rewrite; mutated in place.

    Returns:
        The same (mutated) GraphModule, for chaining.
    """
    remove_tensor_overload_for_qdq_ops(model)
    for rewrite_info in _REWRITE_INFO_LIST:
        example_inputs = rewrite_info.example_inputs
        pattern = get_aten_graph_module(rewrite_info.pattern, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
        replacement = get_aten_graph_module(rewrite_info.replacement, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
        if rewrite_info.pattern_post_trans:
            pattern = rewrite_info.pattern_post_trans(pattern)
        if rewrite_info.replacement_post_trans:
            replacement = rewrite_info.replacement_post_trans(replacement)
        pattern.recompile()  # type: ignore[attr-defined]
        replacement.recompile()  # type: ignore[attr-defined]
        # Return value (match list) is intentionally unused: the rewrite
        # mutates `model` in place.  (Previously bound to an unused local.)
        replace_pattern(model, pattern, replacement)
    return model
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc
ADDED
|
Binary file (9.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Whether to disable showing progress on compilation passes.
# Need a new config here: importing the dynamo config would create a
# circular import.
disable_progress = True

# If True, also show the node names in each pass; great for small models
# but quite noisy for larger ones.
verbose_progress = False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc
ADDED
|
Binary file (62.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc
ADDED
|
Binary file (39.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import itertools
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils._pytree as pytree
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"ShapeEnvEvent",
|
| 12 |
+
"record_shapeenv_event",
|
| 13 |
+
"replay_shape_env_events",
|
| 14 |
+
"FakeTensorMeta",
|
| 15 |
+
"shape_env_check_state_equal",
|
| 16 |
+
"NotEqualError",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
# [Note: Recording ShapeEnv Events]
|
| 20 |
+
# =================================
|
| 21 |
+
#
|
| 22 |
+
# What is a ShapeEnv event?
|
| 23 |
+
# -------------------------
|
| 24 |
+
# We consider a ShapeEnv event every function call (ShapeEnv method or
|
| 25 |
+
# independent function) that modifies the state of the ShapeEnv instance.
|
| 26 |
+
# Such calls are recorded alongside their positional and keyword arguments,
|
| 27 |
+
# so that it may be replayed over a different ShapeEnv instance.
|
| 28 |
+
#
|
| 29 |
+
# See [Note: ShapeEnv State Equality] for what is considered the state
|
| 30 |
+
# of a ShapeEnv instance.
|
| 31 |
+
#
|
| 32 |
+
# What is it for?
|
| 33 |
+
# ---------------
|
| 34 |
+
# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
|
| 35 |
+
# arbitrary state in time.
|
| 36 |
+
#
|
| 37 |
+
# Being able to arbitrarily replay events like so is useful, mainly for
|
| 38 |
+
# translation validation bisection. i.e. if a ValidationException has been
|
| 39 |
+
# raised, find the earliest point in time where the translation validation
|
| 40 |
+
# fails.
|
| 41 |
+
#
|
| 42 |
+
# Besides that, it also allows us to inspect the given instance and,
|
| 43 |
+
# for example, check the guards that would actually be issued at that point.
|
| 44 |
+
#
|
| 45 |
+
# What kind of arguments can be stored in an event?
|
| 46 |
+
# -------------------------------------------------
|
| 47 |
+
# There's no specific rule for what cannot be used as an argument.
|
| 48 |
+
# That said, pay special attention to the following cases:
|
| 49 |
+
#
|
| 50 |
+
# 1. Tensor inputs: there are some tests that check whether the inputs
|
| 51 |
+
# were garbage collected after execution. These will fail if there's
|
| 52 |
+
# an event that is holding a reference to those inputs.
|
| 53 |
+
#
|
| 54 |
+
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
|
| 55 |
+
# will be automatically replaced by the new given ShapeEnv instance.
|
| 56 |
+
#
|
| 57 |
+
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
|
| 58 |
+
# whenever we see them, we create a new instance, replacing the
|
| 59 |
+
# ShapeEnv reference.
|
| 60 |
+
#
|
| 61 |
+
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
|
| 62 |
+
# shapes. That argument must be replaced when replaying the event at
|
| 63 |
+
# ShapeEnvEvent.run, since it has to reference a node from the given
|
| 64 |
+
# instance, and not from the recorded instance.
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Event class for reconstructing ShapeEnv at arbitrary time.
#
# Represents a method call that mutates ShapeEnv in a way that affects the
# issued guards, when ShapeEnv.produce_guards is called.
@dataclass
class ShapeEnvEvent:
    """A recorded, replayable ShapeEnv method call.

    Stores the callable together with the arguments it was invoked with, so
    the same mutation can later be re-applied to a different ShapeEnv
    instance via :meth:`run`.
    """

    # ShapeEnv method.
    f: Callable

    # Arguments and keyword arguments called with.
    args: Optional[List[Any]] = None
    kwargs: Optional[Dict[str, Any]] = None

    # List of tracked_fakes at the time the method was called.
    tracked_fakes: Optional[List[Any]] = None

    # Name of the captured event.
    # Used for special handling of particular methods.
    name: Optional[str] = None

    # Replay itself, but using shape_env as self.
    def run(self, shape_env=None) -> Any:
        """Replay this event on ``shape_env`` and return ``f``'s result.

        For the constructor event (``f is ShapeEnv``) a brand-new ShapeEnv is
        built from the recorded kwargs, and ``shape_env`` must be None. For
        every other event, recorded ShapeEnv / symbolic / FX-node arguments
        are remapped so they reference the given ``shape_env`` (and its FX
        graph) before ``f`` is called.
        """
        from torch.fx.experimental.symbolic_shapes import (
            is_symbolic,
            ShapeEnv,
            SymTypes,
        )

        # Special handling for the constructor event.
        if self.f is ShapeEnv:
            assert shape_env is None and self.args is None and self.kwargs is not None
            return ShapeEnv(**self.kwargs)

        assert shape_env is not None
        args = list(self.args or list())
        kwargs = dict(self.kwargs or dict())

        # Replace any argument of type ShapeEnv by the given one.
        args, kwargs = pytree.tree_map_only(
            ShapeEnv, lambda _: shape_env, (args, kwargs)
        )

        # Replace any argument of type SymTypes by a new instance,
        # replacing its ShapeEnv reference.
        args, kwargs = pytree.tree_map_only(
            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
            lambda a: type(a)(a.node.with_shape_env(shape_env)),
            (args, kwargs),
        )

        # Converts FX nodes using the mapping argument.
        def maybe_convert_node(x: Any) -> Any:
            if not isinstance(x, torch.fx.Node):
                # Don't do anything to x if it's not an FX node.
                return x

            # If, at some point, we created an FX node, it means that translation validation is on.
            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
            # we are tracking node names at shape_env.name_to_node.
            assert hasattr(shape_env, "name_to_node")
            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
            assert x.name in name_to_node
            return name_to_node[x.name]

        # Replaces the value of a specific argument by the result of fn.
        # Handles both positional (by index) and keyword (by key) forms.
        def replacearg(index: int, key: str, fn: Callable):
            if index < len(args):
                args[index] = fn(args[index])
            if key in kwargs:
                kwargs[key] = fn(kwargs[key])

        if self.is_create_fx_call_function():
            # ShapeEnv.create_fx_call_function:
            # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
            # They must be replaced, since a "call_function" FX node with this tuple as argument
            # will be added to the FX graph of the new shape_env.
            replacearg(
                index=2,
                key="args",
                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
            )
        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
            # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
            # "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
            # They must be replaced, since it will be part of a "call_function" FX node for
            # torch._assert, which will be added to the FX graph of the new shape_env.
            replacearg(index=3, key="fx_node", fn=maybe_convert_node)

        # Actually call the method with the converted arguments.
        return self.f(*args, **kwargs)

    def __str__(self) -> str:
        """Human-readable summary: event name plus recorded args/kwargs."""
        name = self.name if self.name is not None else self.f.__name__
        return f"event: {name} ({self.args}, {self.kwargs})"

    def is_create_fx_call_function(self) -> bool:
        """True iff this event records ShapeEnv._create_fx_call_function."""
        return self.name == "_create_fx_call_function"

    def is_evaluate_expr(self) -> bool:
        """True iff this event records ShapeEnv.evaluate_expr."""
        return self.name == "evaluate_expr"

    def is_defer_runtime_assert(self) -> bool:
        """True iff this event records ShapeEnv.defer_runtime_assert."""
        return self.name == "defer_runtime_assert"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Extracts a ShapeEnv instance inside args and kwargs.
|
| 173 |
+
# Specifically, it looks for:
|
| 174 |
+
# 1. ShapeEnv arguments
|
| 175 |
+
# 2. SymInt, SymFloat, or SymBool arguments
|
| 176 |
+
# If we find more than one object of any of the above types, we
|
| 177 |
+
# also check that the ShapeEnv instance is the same for all of them.
|
| 178 |
+
def _extract_shape_env_and_assert_equal(args, kwargs):
|
| 179 |
+
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
|
| 180 |
+
|
| 181 |
+
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
|
| 182 |
+
if old is not None:
|
| 183 |
+
assert old is new, "call with different ShapeEnv"
|
| 184 |
+
return new
|
| 185 |
+
|
| 186 |
+
shape_env = None
|
| 187 |
+
for val in itertools.chain(args, kwargs.values()):
|
| 188 |
+
if isinstance(val, ShapeEnv):
|
| 189 |
+
shape_env = assert_equal(shape_env, val)
|
| 190 |
+
if isinstance(val, SymTypes) and is_symbolic(val):
|
| 191 |
+
shape_env = assert_equal(shape_env, val.node.shape_env)
|
| 192 |
+
|
| 193 |
+
return shape_env
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Decorator for recording the given function as a replayable event.
|
| 197 |
+
#
|
| 198 |
+
# This decorator should be used at every function that mutates the state of
|
| 199 |
+
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
|
| 200 |
+
# ShapeEnv.produce_guards is called).
|
| 201 |
+
#
|
| 202 |
+
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
|
| 203 |
+
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
|
| 204 |
+
#
|
| 205 |
+
# When to save the list of TrackedFake?
|
| 206 |
+
# =====================================
|
| 207 |
+
# We should save the list of TrackedFake whenever the translation validation
|
| 208 |
+
# bisection may actually stop and call the produce_guards method at the moment
|
| 209 |
+
# right after the recorded function was played. In other words, since the
|
| 210 |
+
# bisection bisects through torch._assert calls, we should save in all methods
|
| 211 |
+
# that add a torch._assert call to the symbolic shapes FX graph.
|
| 212 |
+
#
|
| 213 |
+
# At the moment, there are 2 methods that save the list:
|
| 214 |
+
# - ShapeEnv.evaluate_expr
|
| 215 |
+
# - ShapeEnv.defer_runtime_assert
|
| 216 |
+
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
    """Decorator that records calls to a ShapeEnv-mutating method as replayable events.

    The decorated function's arguments must reference at most one ShapeEnv
    instance. On each call (outside an ongoing recording), a ShapeEnvEvent is
    appended to that instance's ``events`` list, and then immediately replayed
    on it — so recording and normal execution stay in sync.

    Args:
        save_tracked_fakes: if True, store a snapshot of the ShapeEnv's
            tracked_fakes list on the event. Per the note above, this is
            needed by methods that add a torch._assert call to the symbolic
            shapes FX graph.

    Returns:
        The decorator to apply to a ShapeEnv method.
    """

    def decorator(fn: Callable) -> Callable:
        assert callable(fn)
        # Hoisted once; also reused as the event name below.
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            if isinstance(args[0], ShapeEnv) and args[0].is_recording:  # type: ignore[has-type]
                # If ShapeEnv is already recording an event, call the wrapped
                # function directly.
                #
                # NB: here, we skip the check of whether all ShapeEnv instances
                # are equal, in favor of a faster dispatch.
                return fn(*args, **kwargs)

            # Retrieve an instance of ShapeEnv.
            # Assumption: the collection of args and kwargs may not reference
            # different ShapeEnv instances.
            self = _extract_shape_env_and_assert_equal(args, kwargs)

            # If we are calling this function without any ShapeEnv instance
            # alive in its arguments, we don't record and call the original.
            if self is None:
                return fn(*args, **kwargs)

            # Otherwise, start recording and call the function.
            with self._recording():
                # Take a snapshot of the current tracked_fakes.
                tracked_fakes = (
                    self._snapshot_tracked_fakes() if save_tracked_fakes else None
                )
                # Record the event for 'fn'. Fix: use the hoisted `name`
                # (previously computed but unused) instead of re-reading
                # fn.__name__ on every call.
                event = ShapeEnvEvent(fn, list(args), kwargs, tracked_fakes, name=name)
                self.events.append(event)
                # Play the event on this ShapeEnv.
                return event.run(self)

        return wrapper

    return decorator
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def replay_shape_env_events(events):
    """Reconstruct a ShapeEnv by replaying a recorded event list.

    The first event must be the ShapeEnv constructor call; every following
    event is re-run against the freshly built instance. Any failure while
    replaying is wrapped in a RuntimeError naming the offending event.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    head, *tail = events
    assert head.f == ShapeEnv

    # Build the new ShapeEnv from the recorded constructor kwargs.
    shape_env = head.run()

    for event in tail:
        try:
            # Replay each mutation on the new instance; event.run remaps
            # recorded arguments (ShapeEnv, symbolic values, FX nodes) to it.
            event.run(shape_env)
        except Exception as e:
            raise RuntimeError(f"failed when running event: {event}") from e

    return shape_env
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# FakeTensor metadata.
# This is to be used in place of FakeTensor placeholders when calling
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
    """Lightweight stand-in exposing the tensor accessors produce_guards uses.

    Mirrors ``size()``, ``stride()``, ``storage_offset()`` and ``dim()`` of a
    (fake) tensor, holding only the recorded values.
    """

    tensor_size: Tuple[Union[int, torch.SymInt], ...]
    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
        """Recorded sizes, matching Tensor.size()."""
        return self.tensor_size

    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
        """Recorded strides, matching Tensor.stride()."""
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        """Recorded storage offset, matching Tensor.storage_offset()."""
        return self.tensor_storage_offset

    def dim(self) -> int:
        """Number of dimensions, derived from the recorded sizes."""
        return len(self.tensor_size)

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        """Snapshot the relevant metadata of a (fake) tensor."""
        return FakeTensorMeta(
            tensor_size=fake.size(),
            tensor_stride=fake.stride(),
            tensor_storage_offset=fake.storage_offset(),
            is_nested=fake.is_nested,
        )
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# [Note: ShapeEnv State Equality]
|
| 317 |
+
# ===============================
|
| 318 |
+
#
|
| 319 |
+
# What is considered ShapeEnv state?
|
| 320 |
+
# ----------------------------------
|
| 321 |
+
# We consider to be the state of a ShapeEnv instance everything that
|
| 322 |
+
# is not in the inline tuple inside remove_nonstate_variables function.
|
| 323 |
+
# That is: the fields within ShapeEnv that modify the flow of execution
|
| 324 |
+
# of the program.
|
| 325 |
+
#
|
| 326 |
+
# So, for example: the replacements field might influence on how an
|
| 327 |
+
# expression is simplified. That, in turn, may result in a guard being
|
| 328 |
+
# statically known (i.e. not added).
|
| 329 |
+
#
|
| 330 |
+
# On the other hand, var_to_stack only changes what is printed
|
| 331 |
+
# on the screen, i.e. used only for debugging purposes. Therefore, we
|
| 332 |
+
# should not consider it when comparing states.
|
| 333 |
+
#
|
| 334 |
+
# What to do on NotEqualError?
|
| 335 |
+
# ----------------------------
|
| 336 |
+
# Here are a few possible causes for getting a NotEqualError raised:
|
| 337 |
+
#
|
| 338 |
+
# 1. New field that does not belong in the ShapeEnv state.
|
| 339 |
+
# For example: log field of type ShapeEnvLoggerAdapter. Different
|
| 340 |
+
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
|
| 341 |
+
# instances, i.e. equality comparison would fail.
|
| 342 |
+
# Solution: add it to the inlined tuple inside remove_nonstate_variables
|
| 343 |
+
# function inside check_equal method.
|
| 344 |
+
#
|
| 345 |
+
# 2. New field that is not directly comparable across instances.
|
| 346 |
+
# For example: guards field of type List[ShapeGuard]. More specifically,
|
| 347 |
+
# the ShapeGuard type holds an expression and a stack information
|
| 348 |
+
# for debugging purposes. When replaying the event on a new ShapeEnv
|
| 349 |
+
# instance, the stack would be different, which would trigger this error.
|
| 350 |
+
# Solution: add a special case to the map_value function inside
|
| 351 |
+
# check_equal function.
|
| 352 |
+
#
|
| 353 |
+
# 3. Mutation of ShapeEnv on some not recorded function.
|
| 354 |
+
# If a mutation of the state of ShapeEnv happens inside a function
|
| 355 |
+
# that is not recorded (or that no caller in the stack is recorded),
|
| 356 |
+
# then, the replayed ShapeEnv won't catch that.
|
| 357 |
+
# Solution: decorate the function with record_shape_env_event.
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    """Check that two ShapeEnv instances agree on all state fields.

    Compares ``vars(env1)`` against ``vars(env2)``, after dropping the fields
    named in ``non_state_variable_names`` (fields that don't affect the
    guards returned by ShapeEnv.produce_guards). Each remaining value is
    passed through ``map_value(field_name, value)`` before comparison, so
    callers can normalize fields that are not directly comparable.

    Raises:
        NotEqualError: if the field sets differ, or any mapped values differ.
    """
    # Filter out non-state fields into fresh dicts, so the instances
    # themselves are never mutated.
    state1 = {k: v for k, v in vars(env1).items() if k not in non_state_variable_names}
    state2 = {k: v for k, v in vars(env2).items() if k not in non_state_variable_names}

    # Render a mismatched value as a string with deterministic ordering,
    # since dict and set iteration order might not be the same every time.
    def value_to_str(value: Any) -> str:
        if isinstance(value, dict):
            inner = ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
            return "{" + inner + "}"
        if isinstance(value, set):
            return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
        return str(value)

    # First, the two instances must have the same set of fields.
    keys1, keys2 = set(state1), set(state2)
    if keys1 != keys2:
        raise NotEqualError(
            "field set mismatch:",
            [
                (
                    "found unique fields:",
                    str(sorted(keys1 - keys2)),
                    str(sorted(keys2 - keys1)),
                ),
            ],
        )

    # Then, walk the fields in sorted order, comparing mapped values and
    # accumulating every mismatch alongside its rendered values.
    errors = []
    for key in sorted(keys1):
        left = map_value(key, state1[key])
        right = map_value(key, state2[key])
        if left != right:
            errors.append(
                (f"{key}: values don't match.", value_to_str(left), value_to_str(right))
            )

    if len(errors) > 0:
        raise NotEqualError("field values don't match:", errors)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class NotEqualError(Exception):
    """Raised when two ShapeEnv states diverge.

    Carries a headline message plus, for each mismatched field, a
    (description, left value, right value) triple rendered into the
    exception text.
    """

    def __init__(
        self,
        msg: str,
        mismatched: List[Tuple[str, str, str]],
    ) -> None:
        # Render one three-line paragraph per mismatched field.
        paragraphs = []
        for inner_msg, str1, str2 in mismatched:
            paragraphs.append(
                "\n".join(
                    [
                        f"==> {inner_msg}",
                        f"  > Left: {str1}",
                        f"  > Right: {str2}",
                    ]
                )
            )
        details = "\n".join(paragraphs)

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}

{details}
"""
        )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Equality:
    """An equality constraint between two refinement expressions: ``lhs = rhs``."""

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def __str__(self):
        return f'{self.lhs} = {self.rhs}'

    def __repr__(self):
        return f'{self.lhs} = {self.rhs}'

    def __eq__(self, other):
        # Structural equality: both sides must match another Equality.
        if isinstance(other, Equality):
            return self.lhs == other.lhs and self.rhs == other.rhs
        else:
            return False

    def __hash__(self):
        # Fix: defining __eq__ alone sets __hash__ to None, making instances
        # unusable as dict keys / set members. Hash consistently with __eq__.
        return hash((self.lhs, self.rhs))
|