koichi12 commited on
Commit
76cb23d
·
verified ·
1 Parent(s): 65e568a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +231 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +94 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py +18 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +71 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml +389 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py +201 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py +401 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py +1025 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py +761 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py +1311 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py +64 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py +200 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py +243 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py +600 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py +6 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py +458 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py +16 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc ADDED
Binary file (1.57 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc ADDED
Binary file (27.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc ADDED
Binary file (20.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc ADDED
Binary file (23.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc ADDED
Binary file (7.31 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc ADDED
Binary file (9.12 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc ADDED
Binary file (1.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (333 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc ADDED
Binary file (1.65 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc ADDED
Binary file (7.57 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import operator
3
+ import traceback
4
+ from functools import partial
5
+ from typing import Callable, Dict, List, NamedTuple, Set
6
+
7
+ import sympy
8
+
9
+ import torch
10
+ import torch.fx
11
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, ProxyValue, PassResult
12
+ from torch.utils._sympy.value_ranges import ValueRanges
13
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
14
+
15
+
16
+ __all__ = ["InputDim"]
17
+
18
+
19
class InputDim(NamedTuple):
    """Identifies one dimension of a named graph input."""

    # Name of the graph input (placeholder) this dimension belongs to.
    input_name: str
    # Index of the dimension within that input's shape.
    dim: int
22
+
23
+
24
def _convert_to_int(val):
    """Convert a simple sympy integer (or +/- oo) into a concrete Python value.

    Unbounded endpoints are mapped onto float infinities so callers can
    compare them with ordinary numeric operators; anything that is not a
    plain sympy Integer is rejected.
    """
    # Symbolic infinities mark an unbounded side of a range.
    for symbolic, concrete in ((sympy.oo, math.inf), (-sympy.oo, -math.inf)):
        if val == symbolic:
            return concrete
    if isinstance(val, sympy.Integer):
        return int(val)
    raise RuntimeError(
        "Export constraints cannot be non-integer expressions"
    )
35
+
36
+
37
def _convert_range_to_int(range: ValueRanges):
    """Return the (min, max) concrete bounds of a sympy ValueRanges."""
    assert isinstance(range, ValueRanges)
    # Each endpoint may be a sympy Integer or +/- oo; normalize both.
    lower_bound = _convert_to_int(range.lower)
    upper_bound = _convert_to_int(range.upper)
    return lower_bound, upper_bound
42
+
43
+
44
class _AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Interpreter-based pass that inserts runtime ``_assert_async`` checks
    enforcing the inline range constraints (``range_constraints``) on
    unbacked symbolic values produced inside the graph.
    """

    def __init__(
        self,
        range_constraints: Dict[sympy.Symbol, ValueRanges],
    ):
        super().__init__()
        # Symbol -> allowed value range, as provided by export.
        self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        # Symbols an assert was already emitted for, to avoid duplicates.
        self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
        # Count of inserted comparison chains; `call` uses it to detect a no-op run.
        self.counter = 0

    def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
        # Only emit checks for bounds that are actually finite.
        if lower > -math.inf:
            self._insert_assert_async(operator.ge, proxy, lower, assert_msg)

        if upper < math.inf:
            self._insert_assert_async(operator.le, proxy, upper, assert_msg)

    def _insert_assert_async(self, operator, lower, upper, assert_msg):
        """
        Inserts assert_async call_function nodes in the graph. This function is
        called **during** the interpreter-based pass.

        NOTE(review): the parameter names are misleading — this is invoked as
        ``_insert_assert_async(cmp_op, proxy, bound, msg)``, so ``operator``
        is a comparison function, ``lower`` is the value proxy, and ``upper``
        is the bound compared against.
        """
        self.counter += 1
        # Emit: cmp = op(value, bound); t = scalar_tensor(cmp); _assert_async.msg(t, msg)
        cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
        cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
        super().call_operator(
            torch.ops.aten._assert_async.msg,
            (cmp_tensor, assert_msg),
            {},
            self._create_dummy_node_metadata(),
        )

    def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
        ret = super().call_operator(op, args, kwargs, meta)
        if "val" not in meta:
            return ret

        val = meta["val"]

        # In general, we may have to deal the case such as: ret[1].shape[0].
        # We need first find out what symbols require assertion, then we need to follow the path
        # from ret to the symbol, construct the proxies along the way and construct the messages
        # piece-wise at the same time.
        #
        # We use post-order traversal to collect all the proxies callbacks needed, construct
        # the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
        # We need the callbacks because, in order to call the function to create a proxy for shape[0], we
        # need the proxy for shape, which further requires the proxy for ret[1], etc.
        def add_assertions(val):
            call_backs: List[Callable] = []
            messages: List[str] = []
            if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
                symbol = val.node.expr
                # Skip symbols the graph already asserts on (collected in `call`).
                if symbol in self.existing_inline_assertions:
                    return call_backs, messages
                if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
                    if symbol in self._asserts_generated_unbacked_symbols:
                        return call_backs, messages
                    # We only care about unbacked symints for these inline
                    # constraints, which are prefixed with 'u'
                    constraint = self.range_constraints[symbol]
                    min_val, max_val = _convert_range_to_int(constraint)
                    assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
                    call_backs.append(
                        partial(self._assert_range_constraint, lower=min_val, upper=max_val)
                    )
                    messages.append(assert_msg)
                    self._asserts_generated_unbacked_symbols.add(symbol)

            elif isinstance(val, torch.Tensor):
                # Recurse into each dimension; wrap the inner callbacks so they
                # receive a proxy for sym_size(tensor, dim) rather than the
                # tensor itself.
                for i, sym in enumerate(val.shape):
                    cbs, msgs = add_assertions(sym)
                    for cb, msg in zip(cbs, msgs):
                        # NOTE(review): `cb` is captured late-bound by the
                        # closure (only `dim` is bound via partial). This is
                        # only correct if `cbs` has at most one element per
                        # dimension — confirm.
                        def sym_size_cb(proxy, assert_msg, dim):
                            dim_proxy = super(
                                _AddRuntimeAssertionsForInlineConstraintsPass,
                                self
                            ).call_operator(
                                torch.ops.aten.sym_size.int,
                                (proxy, dim),
                                {},
                                self._create_dummy_node_metadata(),
                            )
                            cb(proxy=dim_proxy, assert_msg=assert_msg)
                        call_backs.append(partial(sym_size_cb, dim=i))
                        messages.append(f".shape[{i}]" + msg)
            return call_backs, messages

        callbacks, messages = add_assertions(val)
        for cb, msg in zip(callbacks, messages):
            cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
        return ret

    def call(self, graph_module):
        # Record asserts already present in the graph so they aren't duplicated.
        self.existing_inline_assertions = _get_existing_inline_assertions(
            graph_module, self.range_constraints
        )

        # Add runtime asserts for inline constraints
        val = super().call(graph_module)

        # Sometimes this pass would return a wrong graph where we have mismatched
        # node names in signature. Before we fix it, let's just skip it.
        if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
            return PassResult(graph_module, False)

        # Populate the stack trace with dummy vals to respect IR
        for node in val.graph_module.graph.nodes:
            if not node.meta.get("stack_trace", None):
                node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))

        return PassResult(val.graph_module, val.modified)
156
+
157
+
158
def _get_existing_inline_assertions(
    graph_module: torch.fx.GraphModule,
    range_constraints: Dict[sympy.Symbol, ValueRanges],
) -> Dict[sympy.Symbol, ValueRanges]:
    """
    Collect the value ranges already asserted inside ``graph_module``.

    Scans every sub-GraphModule for the ``scalar_tensor`` → ``_assert_async.msg``
    pattern produced by inline constraint checks and returns, per symbol, the
    range that is already enforced, so the assertion pass can skip emitting
    duplicates.

    Raises:
        RuntimeError: if an asserted symbol is absent from ``range_constraints``.
    """
    existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}

    for module in graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue

        # Find all the existing inline assertions. They will look something like:
        # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
        # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
        # %scalar_tensor = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,), kwargs = {})
        # %_assert_async = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor, "..."), kwargs = {})
        for node in module.graph.nodes:
            if node.target != torch.ops.aten._assert_async.msg:
                continue

            scalar_tensor_arg = node.args[0]
            if not (
                scalar_tensor_arg.op == "call_function" and
                scalar_tensor_arg.target == torch.ops.aten.scalar_tensor.default
            ):
                continue

            compare_arg = scalar_tensor_arg.args[0]
            if not (
                compare_arg.op == "call_function" and
                compare_arg.target in (operator.le, operator.ge) and
                len(compare_arg.args) == 2
            ):
                continue

            compare_op = compare_arg.target
            maybe_symint_arg, compare_int = compare_arg.args

            # x >= 0 will sometimes be canonicalized to -x <= 0, so in some
            # cases the operation before the comparison is to multiply by -1. We
            # can undo the canonicalization here
            if (
                maybe_symint_arg.op == "call_function" and
                maybe_symint_arg.target == operator.mul and
                maybe_symint_arg.args[0] == -1
            ):
                maybe_symint_arg = maybe_symint_arg.args[1]
                compare_op = operator.ge
                compare_int = -1 * compare_int

            if not (
                "val" in maybe_symint_arg.meta and
                isinstance(maybe_symint_arg.meta["val"], torch.SymInt)
            ):
                continue

            symint = maybe_symint_arg.meta["val"].node.expr
            if not isinstance(symint, sympy.Symbol):
                continue

            if symint not in range_constraints:
                raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")

            found_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))

            # BUG FIX: branch on `compare_op` (which is flipped to ge when a
            # canonicalized `-x <= c` node was unwound above) instead of the
            # raw `compare_arg.target`. For the canonicalized form the raw
            # target is still operator.le even though the effective comparison
            # is `x >= -c`, so branching on the raw target recorded the bound
            # on the wrong side of the range. `compare_op` was previously a
            # dead store, confirming this was the intended discriminator.
            if compare_op == operator.le:
                existing_inline_assertions[symint] = ValueRanges(
                    lower=found_range.lower, upper=compare_int
                )
            elif compare_op == operator.ge:
                existing_inline_assertions[symint] = ValueRanges(
                    lower=compare_int, upper=found_range.upper
                )

    return existing_inline_assertions
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Dict, Optional, Tuple, List
3
+
4
+ import torch
5
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
6
+ from torch._export.pass_infra.node_metadata import NodeMetadata
7
+ from torch._export.pass_infra.proxy_value import ProxyValue
8
+ from torch._ops import OpOverload
9
+
10
# Short alias used throughout this module.
aten = torch.ops.aten

# Maps each side-effectful (non-functional) op to the functional counterpart
# that this pass substitutes for it; the functional variants accept an extra
# `dep_token` keyword (see call_operator below).
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
    aten._assert_async.msg: aten._functional_assert_async.msg,
}
16
+
17
+
18
class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Functionalize ops with side effect in graph module by replacing the op with
    functional version of it. A new dependency token (`dep_token`) will be
    created and propagated through functional ops to output.
    For example:
    ```
    def f(x):
        sym_constrain_range(x.shape[0], min=1, max=3)
        return x.add(3)
    ```
    Will be transformed to:
    ```
    def f(x):
        dep_token0 = _make_dep_token()
        dep_token1 = _functional_sym_constrain_range(
            x.shape[0], min=1, max=3, dep_token=dep_token0
        )

        return x.add(3), dep_token1
    ```
    """

    def __init__(self) -> None:
        super().__init__()
        # Most recent dependency token proxy; threaded through each
        # functionalized op and finally appended to the graph outputs.
        self._dep_token: Optional[ProxyValue] = None
        # Numeric suffix for the next dep_token node name ("dep_token0" is the
        # initial _make_dep_token result).
        self._next_dep_token_index: Optional[int] = None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Early return if no non-functional assertions.
        if not any(
            n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
            for n in graph_module.graph.nodes
        ):
            return PassResult(graph_module=graph_module, modified=False)

        # Deep-copy so the caller's graph module is left untouched; reset the
        # token state in case this pass instance is reused.
        gm = copy.deepcopy(graph_module)
        self._dep_token = None
        self._next_dep_token_index = None
        return super().call(gm)

    def call_operator(
        self,
        op: OpOverload,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Ops without a functional counterpart pass through unchanged.
        if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
            return super().call_operator(op, args, kwargs, meta)

        # Lazily create the initial dependency token on first use.
        if self._dep_token is None:
            self._dep_token = super().call_operator(
                aten._make_dep_token,
                args=(),
                kwargs={},
                meta=self._create_dummy_node_metadata(),
            )
            self._dep_token.node.name = "dep_token0"
            self._next_dep_token_index = 1

        # Replace the side-effectful op with its functional variant, feeding in
        # the current token and taking the op's result as the new token.
        self._dep_token = super().call_operator(
            _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
            args=args,
            kwargs={**kwargs, "dep_token": self._dep_token},
            meta=meta,
        )
        assert self._next_dep_token_index is not None
        self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
        self._next_dep_token_index += 1

        return self._dep_token

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        assert self._dep_token is not None

        # Append the final token to the graph outputs, as described in the
        # class docstring (presumably so the token chain stays live — confirm).
        return super().output(results=(*results, self._dep_token), meta=meta)  # type: ignore[arg-type]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import torch
4
+
5
+ replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
6
+ torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
7
+ torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int,
8
+ torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default,
9
+ }
10
+
11
+
12
+ def _replace_sym_size_ops_pass(gm: torch.fx.GraphModule):
13
+ for module in gm.modules():
14
+ if not isinstance(module, torch.fx.GraphModule):
15
+ continue
16
+ for node in module.graph.nodes:
17
+ if node.target in replacements:
18
+ node.target = replacements[node.target]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Set
2
+
3
+ import torch
4
+ from torch._ops import OpOverload, OpOverloadPacket, HigherOrderOperator
5
+ from torch._export.error import InternalError
6
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
7
+
8
+
9
+ __all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
10
+
11
+
12
# Ops replaced by an explicit functional counterpart rather than by the
# schema-derived *_copy lookup in get_view_copy_of_view_op.
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
}

# TODO (tmanlaibaatar) remove this after https://github.com/pytorch/pytorch/pull/100749
# Packets passed through untouched by ReplaceViewOpsWithViewCopyOpsPass.
_BLACK_LISTED_OPS: Set[OpOverloadPacket] = {
    torch.ops.aten.sym_size,
    torch.ops.aten.sym_stride,
    torch.ops.aten.sym_numel,
}
22
+
23
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
    """Return True if the schema describes a view op: its first argument
    aliases the output without being written to."""
    if not schema.arguments:
        return False
    first_arg_alias = schema.arguments[0].alias_info
    if first_arg_alias is None:
        return False
    # A writable alias means an in-place/mutating op, not a view.
    return not first_arg_alias.is_write
28
+
29
+
30
def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
    """Return the ``*_copy`` overload corresponding to a view-op schema.

    Returns None when the schema is not an aten view op.

    Raises:
        InternalError: if the expected ``<name>_copy`` packet or its matching
            overload does not exist on ``torch.ops.aten``.
    """
    if not (is_view_op(schema) and schema.name.startswith("aten::")):
        return None

    base_name = schema.name.split("::")[1]
    overload_name = schema.overload_name or "default"

    copy_packet_name = base_name + "_copy"
    if not hasattr(torch.ops.aten, copy_packet_name):
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    copy_packet = getattr(torch.ops.aten, copy_packet_name)
    if not hasattr(copy_packet, overload_name):
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    return getattr(copy_packet, overload_name)
50
+
51
+
52
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Our backend expects pure functional operators. For efficiency
    purposes, we keep view ops around while functionalizing the exported
    program. This pass replaces view ops with view copy ops for backends that
    need AOT memory planning.
    """
    def call_operator(self, op, args, kwargs, meta):
        # Ops with an explicit replacement in the table take priority.
        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
            return super().call_operator(
                (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
            )

        # Blacklisted packets and higher-order operators are passed through
        # untouched (the latter have no `_schema` to inspect).
        if op in _BLACK_LISTED_OPS or isinstance(op, HigherOrderOperator):
            return super().call_operator(op, args, kwargs, meta)

        # For genuine view ops, swap in the corresponding *_copy overload.
        if view_copy_op := get_view_copy_of_view_op(op._schema):
            return super().call_operator(view_copy_op, args, kwargs, meta)

        return super().call_operator(op, args, kwargs, meta)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc ADDED
Binary file (15.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc ADDED
Binary file (5.63 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @generated by update_schema.py
2
+ # checksum<<4c9986f3aba283b1746995fff8fe7005b370c7e288adec65c03030349a4bab60>>
3
+ Argument:
4
+ kind: union
5
+ fields:
6
+ as_none:
7
+ type: Tuple[()]
8
+ as_tensor:
9
+ type: TensorArgument
10
+ as_tensors:
11
+ type: List[TensorArgument]
12
+ as_int:
13
+ type: int
14
+ as_ints:
15
+ type: List[int]
16
+ as_float:
17
+ type: float
18
+ as_floats:
19
+ type: List[float]
20
+ as_string:
21
+ type: str
22
+ as_strings:
23
+ type: List[str]
24
+ as_sym_int:
25
+ type: SymIntArgument
26
+ as_sym_ints:
27
+ type: List[SymIntArgument]
28
+ as_scalar_type:
29
+ type: ScalarType
30
+ as_memory_format:
31
+ type: MemoryFormat
32
+ as_layout:
33
+ type: Layout
34
+ as_device:
35
+ type: Device
36
+ as_bool:
37
+ type: bool
38
+ as_bools:
39
+ type: List[bool]
40
+ as_sym_bool:
41
+ type: SymBoolArgument
42
+ as_sym_bools:
43
+ type: List[SymBoolArgument]
44
+ as_graph:
45
+ type: GraphArgument
46
+ as_optional_tensors:
47
+ type: List[OptionalTensorArgument]
48
+ as_custom_obj:
49
+ type: CustomObjArgument
50
+ as_operator:
51
+ type: str
52
+ BufferMutationSpec:
53
+ kind: struct
54
+ fields:
55
+ arg:
56
+ type: TensorArgument
57
+ buffer_name:
58
+ type: str
59
+ CustomObjArgument:
60
+ kind: struct
61
+ fields:
62
+ name:
63
+ type: str
64
+ class_fqn:
65
+ type: str
66
+ Device:
67
+ kind: struct
68
+ fields:
69
+ type:
70
+ type: str
71
+ index:
72
+ type: Optional[int]
73
+ default: None
74
+ ExportedProgram:
75
+ kind: struct
76
+ fields:
77
+ graph_module:
78
+ type: GraphModule
79
+ opset_version:
80
+ type: Dict[str, int]
81
+ range_constraints:
82
+ type: Dict[str, RangeConstraint]
83
+ schema_version:
84
+ type: SchemaVersion
85
+ dialect:
86
+ type: str
87
+ GradientToParameterSpec:
88
+ kind: struct
89
+ fields:
90
+ arg:
91
+ type: TensorArgument
92
+ parameter_name:
93
+ type: str
94
+ GradientToUserInputSpec:
95
+ kind: struct
96
+ fields:
97
+ arg:
98
+ type: TensorArgument
99
+ user_input_name:
100
+ type: str
101
+ Graph:
102
+ kind: struct
103
+ fields:
104
+ inputs:
105
+ type: List[Argument]
106
+ outputs:
107
+ type: List[Argument]
108
+ nodes:
109
+ type: List[Node]
110
+ tensor_values:
111
+ type: Dict[str, TensorMeta]
112
+ sym_int_values:
113
+ type: Dict[str, SymInt]
114
+ sym_bool_values:
115
+ type: Dict[str, SymBool]
116
+ is_single_tensor_return:
117
+ type: bool
118
+ default: 'False'
119
+ custom_obj_values:
120
+ type: Dict[str, CustomObjArgument]
121
+ default: '{}'
122
+ GraphArgument:
123
+ kind: struct
124
+ fields:
125
+ name:
126
+ type: str
127
+ graph:
128
+ type: Graph
129
+ GraphModule:
130
+ kind: struct
131
+ fields:
132
+ graph:
133
+ type: Graph
134
+ signature:
135
+ type: GraphSignature
136
+ module_call_graph:
137
+ type: List[ModuleCallEntry]
138
+ GraphSignature:
139
+ kind: struct
140
+ fields:
141
+ input_specs:
142
+ type: List[InputSpec]
143
+ output_specs:
144
+ type: List[OutputSpec]
145
+ InputSpec:
146
+ kind: union
147
+ fields:
148
+ user_input:
149
+ type: UserInputSpec
150
+ parameter:
151
+ type: InputToParameterSpec
152
+ buffer:
153
+ type: InputToBufferSpec
154
+ tensor_constant:
155
+ type: InputToTensorConstantSpec
156
+ custom_obj:
157
+ type: InputToCustomObjSpec
158
+ InputToBufferSpec:
159
+ kind: struct
160
+ fields:
161
+ arg:
162
+ type: TensorArgument
163
+ buffer_name:
164
+ type: str
165
+ persistent:
166
+ type: bool
167
+ InputToCustomObjSpec:
168
+ kind: struct
169
+ fields:
170
+ arg:
171
+ type: CustomObjArgument
172
+ custom_obj_name:
173
+ type: str
174
+ InputToParameterSpec:
175
+ kind: struct
176
+ fields:
177
+ arg:
178
+ type: TensorArgument
179
+ parameter_name:
180
+ type: str
181
+ InputToTensorConstantSpec:
182
+ kind: struct
183
+ fields:
184
+ arg:
185
+ type: TensorArgument
186
+ tensor_constant_name:
187
+ type: str
188
+ Layout:
189
+ kind: enum
190
+ fields:
191
+ Unknown: 0
192
+ SparseCoo: 1
193
+ SparseCsr: 2
194
+ SparseCsc: 3
195
+ SparseBsr: 4
196
+ SparseBsc: 5
197
+ _mkldnn: 6
198
+ Strided: 7
199
+ LossOutputSpec:
200
+ kind: struct
201
+ fields:
202
+ arg:
203
+ type: TensorArgument
204
+ MemoryFormat:
205
+ kind: enum
206
+ fields:
207
+ Unknown: 0
208
+ ContiguousFormat: 1
209
+ ChannelsLast: 2
210
+ ChannelsLast3d: 3
211
+ PreserveFormat: 4
212
+ ModuleCallEntry:
213
+ kind: struct
214
+ fields:
215
+ fqn:
216
+ type: str
217
+ signature:
218
+ type: Optional[ModuleCallSignature]
219
+ default: None
220
+ ModuleCallSignature:
221
+ kind: struct
222
+ fields:
223
+ inputs:
224
+ type: List[Argument]
225
+ outputs:
226
+ type: List[Argument]
227
+ in_spec:
228
+ type: str
229
+ out_spec:
230
+ type: str
231
+ NamedArgument:
232
+ kind: struct
233
+ fields:
234
+ name:
235
+ type: str
236
+ arg:
237
+ type: Argument
238
+ Node:
239
+ kind: struct
240
+ fields:
241
+ target:
242
+ type: str
243
+ inputs:
244
+ type: List[NamedArgument]
245
+ outputs:
246
+ type: List[Argument]
247
+ metadata:
248
+ type: Dict[str, str]
249
+ OptionalTensorArgument:
250
+ kind: union
251
+ fields:
252
+ as_tensor:
253
+ type: str
254
+ as_none:
255
+ type: Tuple[()]
256
+ OutputSpec:
257
+ kind: union
258
+ fields:
259
+ user_output:
260
+ type: UserOutputSpec
261
+ loss_output:
262
+ type: LossOutputSpec
263
+ buffer_mutation:
264
+ type: BufferMutationSpec
265
+ gradient_to_parameter:
266
+ type: GradientToParameterSpec
267
+ gradient_to_user_input:
268
+ type: GradientToUserInputSpec
269
+ user_input_mutation:
270
+ type: UserInputMutationSpec
271
+ RangeConstraint:
272
+ kind: struct
273
+ fields:
274
+ min_val:
275
+ type: int
276
+ max_val:
277
+ type: int
278
+ ScalarType:
279
+ kind: enum
280
+ fields:
281
+ UNKNOWN: 0
282
+ BYTE: 1
283
+ CHAR: 2
284
+ SHORT: 3
285
+ INT: 4
286
+ LONG: 5
287
+ HALF: 6
288
+ FLOAT: 7
289
+ DOUBLE: 8
290
+ COMPLEXHALF: 9
291
+ COMPLEXFLOAT: 10
292
+ COMPLEXDOUBLE: 11
293
+ BOOL: 12
294
+ BFLOAT16: 13
295
+ SchemaVersion:
296
+ kind: struct
297
+ fields:
298
+ major:
299
+ type: int
300
+ minor:
301
+ type: int
302
+ SymBool:
303
+ kind: union
304
+ fields:
305
+ as_expr:
306
+ type: SymExpr
307
+ as_bool:
308
+ type: bool
309
+ SymBoolArgument:
310
+ kind: union
311
+ fields:
312
+ as_name:
313
+ type: str
314
+ as_bool:
315
+ type: bool
316
+ SymExpr:
317
+ kind: struct
318
+ fields:
319
+ expr_str:
320
+ type: str
321
+ hint:
322
+ type: Optional[SymExprHint]
323
+ default: None
324
+ SymExprHint:
325
+ kind: union
326
+ fields:
327
+ as_int:
328
+ type: int
329
+ as_float:
330
+ type: float
331
+ as_bool:
332
+ type: bool
333
+ SymInt:
334
+ kind: union
335
+ fields:
336
+ as_expr:
337
+ type: SymExpr
338
+ as_int:
339
+ type: int
340
+ SymIntArgument:
341
+ kind: union
342
+ fields:
343
+ as_name:
344
+ type: str
345
+ as_int:
346
+ type: int
347
+ TensorArgument:
348
+ kind: struct
349
+ fields:
350
+ name:
351
+ type: str
352
+ TensorMeta:
353
+ kind: struct
354
+ fields:
355
+ dtype:
356
+ type: ScalarType
357
+ sizes:
358
+ type: List[SymInt]
359
+ requires_grad:
360
+ type: bool
361
+ device:
362
+ type: Device
363
+ strides:
364
+ type: List[SymInt]
365
+ storage_offset:
366
+ type: SymInt
367
+ layout:
368
+ type: Layout
369
+ UserInputMutationSpec:
370
+ kind: struct
371
+ fields:
372
+ arg:
373
+ type: TensorArgument
374
+ user_input_name:
375
+ type: str
376
+ UserInputSpec:
377
+ kind: struct
378
+ fields:
379
+ arg:
380
+ type: Argument
381
+ UserOutputSpec:
382
+ kind: struct
383
+ fields:
384
+ arg:
385
+ type: Argument
386
+ SCHEMA_VERSION:
387
+ - 5
388
+ - 1
389
+ TREESPEC_VERSION: 1
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import defaultdict
3
+ from typing import Tuple, Dict, Optional, List
4
+
5
+ import torch
6
+ from torch.export import export
7
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
8
+ from torch._export.pass_infra.node_metadata import NodeMetadata
9
+ from torch._export.pass_infra.proxy_value import ProxyValue
10
+ from torch._subclasses import FakeTensor
11
+ from torch.fx.node import Target, Argument
12
+ from torch.library import Library
13
+ from torch.utils._pytree import tree_unflatten
14
+ import torch._export.exported_program as ep
15
+ import re
16
+
17
# Library fragments used by GraphModuleOpUpgrader._populate_passes to register
# old op versions ("aten" fragment) and their CompositeImplicitAutograd kernels.
lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")

# Module-level logger; used to report missing upgraders during parsing.
log = logging.getLogger(__name__)
21
+
22
+
23
def get_target_version(versioned_upgrader_name: str) -> int:
    """Return the version an upgrader name targets.

    A name such as ``div_Scalar_0_3`` means the upgrader applies to
    ``div.Scalar`` versions 0 through 3 and upgrades to version 4, i.e.
    the trailing number plus one.
    """
    if re.match(r"^.*_[0-9]+_[0-9]+$", versioned_upgrader_name) is None:
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    *_, last_supported = versioned_upgrader_name.split("_")
    return int(last_supported) + 1
30
+
31
+
32
+ def get_upgraders() -> Dict[str, Tuple[str, str]]:
33
+ """Getting upgraders entry map and operator version map and merge them into one dict."""
34
+ upgraders = torch._C._get_upgraders_entry_map()
35
+ op_version_map = torch._C._get_operator_version_map()
36
+ output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
37
+ for opname, entry_list in op_version_map.items():
38
+ if not entry_list:
39
+ raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
40
+ entry = entry_list[0]
41
+ old_schema = entry.old_schema
42
+ upgrader_name = entry.upgrader_name
43
+ upgrader_str = upgraders.get(upgrader_name, None)
44
+ if not upgrader_str:
45
+ raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}")
46
+ output[upgrader_name] = (old_schema, upgrader_str)
47
+ return output
48
+
49
+
50
+ class GraphModuleOpUpgrader:
51
+ """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available.
52
+ To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In
53
+ __init__() it does the following:
54
+ 1. parse the upgrader list and reorder for upgrading purpose.
55
+ 2. register old versions of operators as custom ops.
56
+ 3. prepare upgrader passes.
57
+
58
+ In `upgrade()` API run these upgrader passes.
59
+
60
+ An example of op_upgraders input:
61
+ {
62
+ "aten::div__Scalar_0_3": ( # versioned op name
63
+ "div._Scalar(self: Tensor, other: Scalar)", # old schema
64
+ '''
65
+ def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string
66
+ if (self.is_floating_point() or isinstance(other, float)):
67
+ return self.true_divide_(other)
68
+ return self.divide_(other, rounding_mode='trunc')
69
+ ''',
70
+ ),
71
+ },
72
+
73
+ Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
74
+ original TorchScript upgrader).
75
+ """
76
+
77
+ class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
78
+ def __init__(self, old_target: Target, new_target: Target):
79
+ super().__init__()
80
+ self.old_target = old_target
81
+ self.new_target = new_target
82
+
83
+ def call_operator(
84
+ self,
85
+ op,
86
+ args: Tuple[Argument, ...],
87
+ kwargs: Dict[str, Argument],
88
+ meta: NodeMetadata,
89
+ ) -> ProxyValue:
90
+ if op == self.old_target:
91
+ return super().call_operator(self.new_target, args, kwargs, meta)
92
+ return super().call_operator(op, args, kwargs, meta)
93
+
94
+ def __init__(
95
+ self,
96
+ compiler_opset_version: Optional[Dict[str, int]] = None,
97
+ model_opset_version: Optional[Dict[str, int]] = None,
98
+ op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
99
+ ):
100
+ self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders
101
+ self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {}
102
+ self.model_opset_version = model_opset_version if model_opset_version else {}
103
+ self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes(
104
+ self._parse_upgraders(self.op_upgraders))
105
+
106
+ def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]:
107
+ """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well
108
+ as the upgrader function string literal."""
109
+ # TODO(larryliu0820): Add support for custom ops
110
+ op_namespace = "aten"
111
+ if not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version:
112
+ return []
113
+ model_ver = self.model_opset_version[op_namespace]
114
+ curr_ver = self.compiler_opset_version[op_namespace]
115
+
116
+ # key is the target version. div__Scalar_0_3 should have a key of 4.
117
+ versioned_upgraders: Dict[int, Tuple[str, str]] = {get_target_version(name): v for name, v in
118
+ op_upgraders.items()}
119
+ target_upgraders: List[Tuple[str, str]] = []
120
+ # we need all upgraders from model_ver + 1 to curr_ver, inclusively
121
+ for ver in range(model_ver + 1, curr_ver + 1):
122
+ if ver in versioned_upgraders:
123
+ target_upgraders.append(versioned_upgraders[ver])
124
+ else:
125
+ # we may be able to get away with missing upgraders, if that operator is missing from given graph
126
+ # module.
127
+ log.warning("Missing an upgrader to upgrade to version {ver}.", extra={"ver": ver})
128
+
129
+ return target_upgraders
130
+
131
+ @staticmethod
132
+ def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
133
+ """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
134
+ upgraders. se torch.Library API to register old ops. Op name will be
135
+ <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:
136
+
137
+ lib = Library("aten", "FRAGMENT")
138
+ lib.define(old_schema)
139
+
140
+ impl_lib = Library("aten", "IMPL")
141
+ impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")
142
+
143
+ @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
144
+ upgrader function literal text.
145
+ @:return upgrader passes, order matters
146
+ """
147
+
148
+ upgrader_passes = []
149
+
150
+ def register_old_op(name: str, schema: str, impl_str: str):
151
+ """Registers an old version operator using impl_name as old op name."""
152
+ lib.define(schema)
153
+ try:
154
+ exec(impl_str)
155
+ except Exception as e:
156
+ raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
157
+ impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")
158
+
159
+ for (schema, upgrader_str) in upgraders:
160
+ upgrader_name = upgrader_str.split('(')[0].split(' ')[-1]
161
+ op_name = schema.split('(')[0].split("::")[-1]
162
+ schema = schema.replace(op_name, upgrader_name)
163
+ try:
164
+ register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str)
165
+ except RuntimeError as e:
166
+ if "with the same name and overload name multiple times" in str(e):
167
+ print(f"Registering {upgrader_name} multiple times")
168
+ else:
169
+ raise RuntimeError from e
170
+ old_op_target = getattr(torch.ops.aten, upgrader_name).default
171
+ # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the
172
+ # "default" at the end.
173
+ op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2])
174
+ new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
175
+ # Note that the graph will have op names in the graph, but actually they are of old versions.
176
+ upgrader_passes.append(
177
+ GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target))
178
+
179
+ return upgrader_passes
180
+
181
+ def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram:
182
+ """Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of
183
+ operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the
184
+ upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the
185
+ upgrader. After all passes are applied, the exported program will be upgraded to the target version."""
186
+ if not self.upgrader_passes:
187
+ return exported_program
188
+
189
+ args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"]
190
+ args_real_tensors = [torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg for
191
+ arg in args]
192
+ assert exported_program.call_spec.in_spec is not None
193
+ args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec)
194
+ assert kwargs == {}
195
+
196
+ for _pass in self.upgrader_passes:
197
+ upgraded_program = exported_program._transform_do_not_use(_pass)
198
+ # NB: we have to retrace the graph_module instead of ep because of some failure.
199
+ exported_program = export(upgraded_program.module(), args, kwargs)
200
+
201
+ return exported_program
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import warnings
3
+ from functools import wraps
4
+ from itertools import chain
5
+
6
+ from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple
7
+
8
+ import torch
9
+ import torch._prims_common as utils
10
+ from torch._prims_common import (
11
+ CustomOutParamAnnotation,
12
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
13
+ Number,
14
+ NumberType,
15
+ ShapeType,
16
+ TensorLike,
17
+ TensorLikeType,
18
+ )
19
+ from torch.utils import _pytree as pytree
20
+ from torch.utils._pytree import tree_flatten, tree_unflatten
21
+
22
+
23
+ @overload
24
+ def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
25
+ pass
26
+
27
+
28
+ @overload
29
+ def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
30
+ pass
31
+
32
+
33
+ @overload
34
+ def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
35
+ pass
36
+
37
+
38
+ @overload
39
+ def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
40
+ pass
41
+
42
+
43
+ # TODO: implement ref.cast with an option to enforce safe casting
44
+ def _maybe_convert_to_dtype(a, dtype):
45
+ if isinstance(a, TensorLike):
46
+ if a.dtype != dtype:
47
+ return a.to(dtype)
48
+ return a
49
+ if isinstance(a, Number):
50
+ return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type]
51
+ if isinstance(a, Sequence):
52
+ return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
53
+ # Passthrough None because some functions wrapped with type promotion
54
+ # wrapper might have optional args
55
+ if a is None:
56
+ return None
57
+
58
+ raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!")
59
+
60
+
61
+ def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
62
+ if not isinstance(a, Number):
63
+ msg = f"Found unknown type {type(a)} when trying to convert scalars!"
64
+ raise ValueError(msg)
65
+ if not utils.is_weakly_lesser_type(type(a), typ):
66
+ msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
67
+ raise ValueError(msg)
68
+
69
+ return typ(a)
70
+
71
+
72
+ def _annotation_has_type(*, typ, annotation):
73
+ if hasattr(annotation, "__args__"):
74
+ for a in annotation.__args__:
75
+ if _annotation_has_type(typ=typ, annotation=a):
76
+ return True
77
+ return False
78
+
79
+ return typ is annotation
80
+
81
+
82
class elementwise_type_promotion_wrapper:
    """
    Decorator adding elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args is a Sequence of argument names that participate in type
    promotion (and are promoted). If an argument is itself a Sequence, every element
    of it participates.

    type_promotion_kind is one of ELEMENTWISE_TYPE_PROMOTION_KIND; see its
    documentation for details.

    If the wrapped function has a non-None ``dtype`` argument, it overrides the
    computed result dtype.

    Other type promotion behavior, like validating the Python type of scalar
    arguments, must be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Collect the values of the promotion-participating arguments that were
            # actually supplied by the caller.
            participating = tuple(
                bound.arguments[name]
                for name in self.type_promoting_arg_names  # type: ignore[union-attr]
                if name in bound.arguments
            )

            flat_participating = pytree.arg_tree_leaves(*participating)
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flat_participating,
                type_promotion_kind=self.type_promotion_kind,
            )

            # Promote the participating arguments in place before calling fn.
            bound.arguments.update(
                {
                    name: _maybe_convert_to_dtype(bound.arguments[name], compute_dtype)
                    for name in self.type_promoting_arg_names  # type: ignore[union-attr]
                    if name in bound.arguments
                }
            )

            result = fn(**bound.arguments)

            # An explicit, non-None dtype argument overrides the promoted result dtype.
            maybe_dtype = bound.arguments.get("dtype", None)
            if maybe_dtype:
                result_dtype = maybe_dtype

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(r, result_dtype) for r in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
153
+
154
+
155
+ # Returns True if resize is necessary
156
+ def _resize_output_check(out: TensorLikeType, shape: ShapeType):
157
+ # If the shapes are correct there's nothing to do
158
+ if utils.same_shape(out.shape, shape):
159
+ return False
160
+ if out.numel() != 0:
161
+ msg = (
162
+ f"An output with one or more elements was resized since it had shape {str(out.shape)} "
163
+ "which does not match the required output shape {str(shape)}. "
164
+ "This behavior is deprecated, and in a future PyTorch release outputs will not "
165
+ "be resized unless they have zero elements. "
166
+ "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
167
+ )
168
+ warnings.warn(msg)
169
+ return True
170
+
171
+
172
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    """Resize *out* in place to *shape* when required, returning *out* either way."""
    needs_resize = _resize_output_check(out, shape)
    return out.resize_(shape) if needs_resize else out
178
+
179
+
180
+ def _safe_copy_out(
181
+ *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
182
+ ):
183
+ # Checks same device
184
+ if copy_from.device != copy_to.device:
185
+ msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format(
186
+ copy_from.device, copy_to.device
187
+ )
188
+ raise RuntimeError(msg)
189
+
190
+ # Checks safe cast
191
+ if exact_dtype:
192
+ torch._check(
193
+ copy_from.dtype == copy_to.dtype,
194
+ lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
195
+ f"but got {copy_to.dtype} instead",
196
+ )
197
+ else:
198
+ torch._check(
199
+ utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
200
+ lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
201
+ "but this can't be cast because it is not safe!",
202
+ )
203
+
204
+ return copy_to.copy_(copy_from)
205
+
206
+
207
def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
    """Decorator factory that adds an ``out=`` keyword parameter to a Python reference.

    out_names: names of the output parameters in the Aten API; defaults to ("out",).
        A single name means ``out`` is one tensor; several names mean a tuple.
    exact_dtype: require the out tensor dtype to match the result exactly instead of
        allowing a safe cast.
    pass_is_out: forward ``is_out=<bool>`` to the wrapped function so it can tell
        whether an out tensor was supplied.
    """
    # The wrapped function needs to convert the output parameters to ensure
    # compatibility between the Python API (which always uses "out" as the
    # parameter name and may be a tuple) and the Aten API (which may have
    # multiple output parameters and use different parameter names such as
    # "grad_input", "indices" or "values".)

    default_out_names = ("out",)
    if len(out_names) == 0:
        # Use default in out name
        out_names = default_out_names

    # Single out name -> out is one tensor; several names -> out is a tuple.
    is_tensor = len(out_names) == 1

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        # Annotation for the out parameter: a tensor, or a fixed-length tuple of tensors.
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        # Return annotation: a tensor, or an Aten-style named tuple of tensors.
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        # Factory functions (those taking device AND dtype) inherit those attributes
        # from the out tensor when the caller did not pass them explicitly.
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            if pass_is_out:
                result = fn(*args, is_out=(out is not None), **kwargs)
            else:
                result = fn(*args, **kwargs)
            # Shape of the result must agree with out_names: a single tensor, or a
            # tuple with one element per out name. (Relies on `and` binding tighter
            # than `or`.)
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                # assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    torch._check_type(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tuple
        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
            sig.empty,
            out_type,
        )
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type

        # In the special case of having a single tensor out parameter with a
        # name other than out, add a special annotation to name the parameter
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]

        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable e.g.
        # jit
        _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"  # type: ignore[attr-defined]

        return _fn

    return _out_wrapper
329
+
330
+
331
+ def _maybe_remove_out_wrapper(fn: Callable):
332
+ return inspect.unwrap(
333
+ fn,
334
+ stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
335
+ )
336
+
337
+
338
+ def backwards_not_supported(prim):
339
+ def redispatch_prim(args, kwargs):
340
+ with torch._C._AutoDispatchBelowAutograd():
341
+ old = torch._C._dispatch_tls_is_dispatch_key_excluded(
342
+ torch._C.DispatchKey.ADInplaceOrView
343
+ )
344
+ return prim(*args, **kwargs)
345
+
346
+ class BackwardsNotSupported(torch.autograd.Function):
347
+ @staticmethod
348
+ def forward(ctx, args_spec, *flat_args):
349
+ args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
350
+ return redispatch_prim(args, kwargs)
351
+
352
+ @staticmethod
353
+ def backward(ctx, *args):
354
+ raise RuntimeError("backwards not supported on prim")
355
+
356
+ @wraps(prim)
357
+ def _autograd_impl(*args, **kwargs):
358
+ flat_args, args_spec = tree_flatten((args, kwargs))
359
+ if torch.is_grad_enabled() and any(
360
+ a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
361
+ ):
362
+ # TODO: There is a subtle bug here: prims like copy_to
363
+ # return their input argument after mutating it; and custom
364
+ # autograd function will incorrectly turn the result into
365
+ # a view which will fail test_python_ref_executor tests.
366
+ # At the moment, we sidestep this by observing that the
367
+ # unit tests don't ever try to run the executor with
368
+ # autograd, so we don't exercise the buggy case, but if
369
+ # you ever want to feed autograd through this, be aware
370
+ # of it! We need a way of properly implementing autograd
371
+ # for mutating operations in Python to do this.
372
+ return BackwardsNotSupported.apply(args_spec, *flat_args)
373
+ else:
374
+ return redispatch_prim(args, kwargs)
375
+
376
+ return _autograd_impl
377
+
378
+
379
# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.

    A Python-number first argument is lifted to a 0-dim tensor of the matching
    dtype, the op is applied, and the scalar is extracted from the result.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        first = args[0] if args else None
        if isinstance(first, Number):
            scalar_dtype = utils.type_to_dtype(type(first))
            lifted_args = (torch.tensor(first, dtype=scalar_dtype),) + args[1:]
            result = fn(*lifted_args, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py ADDED
@@ -0,0 +1,1025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains tooling to compare weights and activations
3
+ across models. Example usage::
4
+
5
+ import copy
6
+ import torch
7
+ import torch.ao.quantization.quantize_fx as quantize_fx
8
+ import torch.ao.ns._numeric_suite_fx as ns
9
+
10
+ m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
11
+ mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
12
+ # We convert a copy because we need the original prepared model
13
+ # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
14
+ mq = quantize_fx.convert_fx(copy.deepcopy(mp))
15
+
16
+ #
17
+ # Comparing weights
18
+ #
19
+
20
+ # extract weight pairs
21
+ weight_comparison = ns.extract_weights('a', mp, 'b', mq)
22
+
23
+ # add SQNR for each comparison, inplace
24
+ ns.extend_logger_results_with_comparison(
25
+ weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
26
+ 'sqnr')
27
+
28
+ # weight_comparison contains the weights from `mp` and `mq` stored
29
+ # in pairs, and can be used for further analysis.
30
+
31
+
32
+ #
33
+ # Comparing activations, with error propagation
34
+ #
35
+
36
+ # add loggers
37
+ mp_ns, mq_ns = ns.add_loggers(
38
+ 'a', copy.deepcopy(mp),
39
+ 'b', copy.deepcopy(mq),
40
+ ns.OutputLogger)
41
+
42
+ # send an example datum to capture intermediate activations
43
+ datum = torch.randn(1, 1, 1, 1)
44
+ mp_ns(datum)
45
+ mq_ns(datum)
46
+
47
+ # extract intermediate activations
48
+ act_comparison = ns.extract_logger_info(
49
+ mp_ns, mq_ns, ns.OutputLogger, 'b')
50
+
51
+ # add SQNR for each comparison, inplace
52
+ ns.extend_logger_results_with_comparison(
53
+ act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
54
+ 'sqnr')
55
+
56
+ # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
57
+ # in pairs, and can be used for further analysis.
58
+
59
+ #
60
+ # Comparing activations, without error propagation
61
+ #
62
+
63
+ # create shadow model
64
+ mp_shadows_mq = ns.add_shadow_loggers(
65
+ 'a', copy.deepcopy(mp),
66
+ 'b', copy.deepcopy(mq),
67
+ ns.OutputLogger)
68
+
69
+ # send an example datum to capture intermediate activations
70
+ datum = torch.randn(1, 1, 1, 1)
71
+ mp_shadows_mq(datum)
72
+
73
+ # extract intermediate activations
74
+ shadow_act_comparison = ns.extract_shadow_logger_info(
75
+ mp_shadows_mq, ns.OutputLogger, 'b')
76
+
77
+ # add SQNR for each comparison, inplace
78
+ ns.extend_logger_results_with_comparison(
79
+ shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
80
+ 'sqnr')
81
+
82
+ # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored
83
+ # in pairs, and can be used for further analysis.
84
+
85
+ """
86
+
87
+ import collections
88
+
89
+ import torch
90
+ import torch.nn as nn
91
+ import torch.ao.quantization.quantize_fx as quantize_fx
92
+ from torch.fx import GraphModule
93
+ from torch.fx.graph import Node
94
+ from torch.ao.ns.fx.mappings import (
95
+ get_base_name_to_sets_of_related_ops,
96
+ )
97
+ from torch.ao.ns.fx.graph_matcher import (
98
+ get_matching_subgraph_pairs,
99
+ get_type_a_related_to_b,
100
+ )
101
+
102
+ from .fx.weight_utils import (
103
+ extract_weight_from_node,
104
+ )
105
+
106
+ from .fx.graph_passes import (
107
+ add_loggers_to_model,
108
+ create_a_shadows_b,
109
+ )
110
+
111
+ from .fx.utils import (
112
+ rekey_logger_info_on_node_name_of_model,
113
+ maybe_add_missing_fqns,
114
+ get_target_type_str,
115
+ )
116
+
117
+ from .fx.ns_types import (
118
+ NSSingleResultValuesType,
119
+ NSResultsType,
120
+ NSNodeTargetType,
121
+ )
122
+ from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
123
+ from torch.ao.quantization.backend_config import BackendConfig
124
+ from torch.ao.quantization.fx.match_utils import _find_matches
125
+ from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
126
+ from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
127
+ from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
128
+ from torch.ao.quantization.qconfig import QConfigAny
129
+ from torch.ao.quantization import QConfigMapping
130
+ from torch.ao.ns.fx.n_shadows_utils import (
131
+ OutputProp,
132
+ _get_dedup_subgraphs,
133
+ SHADOW_WRAPPER_NODE_NAME_PREFIX,
134
+ group_results_by_subgraph,
135
+ create_results_comparison,
136
+ print_n_shadows_summary,
137
+ create_n_transformed_and_logged_copies_of_subgraph,
138
+ create_add_loggers_graph,
139
+ extract_weight_comparison,
140
+ )
141
+ from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
142
+
143
+ from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
144
+
145
+ RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
146
+
147
class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.

    Instances are spliced into an FX graph as pass-through nodes: calling one
    returns its input unchanged and, when recording is enabled, stashes a
    detached copy in ``self.stats`` (or ``self.stats_rnn`` for LSTM-style
    ``(output, (hidden, cell))`` values).
    """
    # Container attributes must be annotated at class level for TorchScript.
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = '',
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # Name of the node which was responsible for adding this logger.
        # Note:
        # - when logging node outputs, this equals prev_node_name
        # - when logging node inputs, this is the name of the node whose
        #   input this logger is logging
        #
        # Example, where logger1 logs the input of op1 and logger2 logs
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # here:
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding
        # this logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node,
        # e.g. in cat([x1, x2, x3], dim=0), x2 has index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node,
        # e.g. in add(x1, x2), x2 has index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn
        # loggers may be added before prepare_fx; when callers do not want
        # calibration-time data, this flag lets them collect results only
        # after convert_fx
        self.enabled = True
        # string representation of qconfig
        self.qconfig_str = qconfig_str
        # this can be turned off to reduce memory usage during calibration
        self.save_activations = True

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not self.enabled:
            return x
        if not self.save_activations:
            return x
        # TODO(future PR): consider refactoring this to better reuse the
        # parent class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            # LSTM-style value: (output, (hidden, cell))
            detached = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(detached)
        return x

    def __repr__(self):
        # hide nn.Module internals so the repr stays readable
        visible_attrs = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != 'training') and not k.startswith('_')
        }
        return f"OutputLogger({visible_attrs})"
250
+
251
+
252
class OutputComparisonLogger(OutputLogger):
    """
    Same as OutputLogger, but also requires the original activation
    in order to calculate the comparison at calibration time
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        # human-readable name of the comparison function, copied into results
        self.comparison_fn_name = 'sqnr'
        # precalculated comparisons of logger output versus reference
        self.comparisons = []

    def forward(self, x, x_ref):
        """
        """  # blank docblock to make autodoc happy
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported'
        if self.save_activations:
            # save the activation, for debugging
            self.stats.append(x.detach())
        # save the comparison
        self.comparisons.append(self.comparison_fn(x, x_ref))
        return x

    def __repr__(self):
        # hide nn.Module internals so the repr stays readable
        visible_attrs = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != 'training') and not k.startswith('_')
        }
        return f"OutputComparisonLogger({visible_attrs})"
288
+
289
+
290
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular FX quantization tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        """  # blank docblock to make autodoc happy
        # observers and fake-quant modules are kept as leaves so their calls
        # survive tracing; everything else defers to the parent tracer
        observer_types = (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
        )
        if isinstance(m, observer_types):
            return True
        return super().is_leaf_module(m, module_qualified_name)
303
+
304
+
305
+ def _extract_weights_one_model(
306
+ model_name: str,
307
+ model: GraphModule,
308
+ nodes_and_names_to_instrument: List[Tuple[Node, str]],
309
+ results: NSResultsType,
310
+ op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
311
+ ) -> None:
312
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
313
+ for node, ref_name in nodes_and_names_to_instrument:
314
+ res_type = NSSingleResultValuesType.WEIGHT.value
315
+ extracted_weight = extract_weight_from_node(
316
+ node, model, op_to_type_to_weight_extraction_fn)
317
+ if extracted_weight:
318
+ if ref_name not in results:
319
+ results[ref_name] = {res_type: {}}
320
+ results[ref_name][res_type][model_name] = [extracted_weight]
321
+
322
+
323
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """
    Match subgraphs between two traced models and extract the weights of the
    matched base ops from each of them into a single results structure keyed
    on the node names of `gm_b`.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _) in matched_subgraph_pairs.items()
    ]
    nodes_and_names_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_, subgraph_b) in matched_subgraph_pairs.items()
    ]

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
361
+
362
+
363
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """
    Extract weights from model A and model B, and return a comparison.

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    # Note: a previous revision also computed `get_type_a_related_to_b(...)`
    # here, but the result was never used; the dead computation was removed.

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # propagate scope info so logger FQNs can be filled in later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
412
+
413
+
414
+ def _add_loggers_one_model(
415
+ model_name: str,
416
+ model: GraphModule,
417
+ nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
418
+ nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
419
+ logger_cls: Callable,
420
+ ) -> nn.Module:
421
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
422
+
423
+ # TODO(future PR): do not observe nodes we do not care
424
+ # about (both fp32, denylist, etc)
425
+ node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
426
+ node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
427
+ for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
428
+ node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
429
+ for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
430
+ node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
431
+
432
+ model = add_loggers_to_model(
433
+ model, node_to_instrument_inputs_to_ref_name,
434
+ node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
435
+ return model
436
+
437
+
438
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Match subgraphs between `gm_a` and `gm_b`, then instrument both models
    with `logger_cls` loggers on the matched nodes.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    inputs_to_instrument_a = []
    inputs_to_instrument_b = []
    outputs_to_instrument_a = []
    outputs_to_instrument_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            inputs_to_instrument_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            inputs_to_instrument_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        outputs_to_instrument_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        outputs_to_instrument_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, inputs_to_instrument_a,
        outputs_to_instrument_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, inputs_to_instrument_b,
        outputs_to_instrument_b, logger_cls)
    return (new_model_a, new_model_b)
480
+
481
+
482
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to also log the inputs of matched subgraphs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # propagate scope info so logger FQNs can be filled in later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)
527
+
528
+
529
+ def _extract_logger_info_one_model(
530
+ model: nn.Module,
531
+ results: NSResultsType,
532
+ logger_cls: Callable,
533
+ ) -> None:
534
+ torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
535
+ for gm_name, mod in model.named_modules():
536
+ # TODO(future PR): better check when scripted
537
+ is_logger = (
538
+ isinstance(mod, logger_cls) # type: ignore[arg-type]
539
+ or (
540
+ isinstance(mod, torch.jit.RecursiveScriptModule)
541
+ and mod.original_name == 'OutputLogger'
542
+ )
543
+ )
544
+ if is_logger:
545
+ key = mod.ref_name
546
+ if key not in results:
547
+ results[key] = {}
548
+ assert mod.model_name not in results[key], \
549
+ f"{mod.model_name} is already present in results"
550
+ if mod.results_type not in results[key]:
551
+ results[key][mod.results_type] = {}
552
+ if mod.model_name not in results[key][mod.results_type]:
553
+ results[key][mod.results_type][mod.model_name] = []
554
+ stats_to_use = mod.stats
555
+ if len(mod.stats_rnn) > 0:
556
+ stats_to_use = mod.stats_rnn
557
+ data = {
558
+ 'type': mod.results_type,
559
+ 'values': stats_to_use,
560
+ 'ref_node_name': mod.ref_node_name,
561
+ 'ref_node_target_type': mod.ref_node_target_type,
562
+ 'prev_node_name': mod.prev_node_name,
563
+ 'prev_node_target_type': mod.prev_node_target_type,
564
+ 'index_within_arg': mod.index_within_arg,
565
+ 'index_of_arg': mod.index_of_arg,
566
+ 'fqn': mod.fqn,
567
+ 'qconfig_str': mod.qconfig_str,
568
+ }
569
+ if hasattr(mod, 'comparisons'):
570
+ data['comparisons'] = mod.comparisons
571
+ data['comparison_fn_name'] = mod.comparison_fn_name
572
+ else:
573
+ data['comparisons'] = []
574
+ data['comparison_fn_name'] = ''
575
+ results[key][mod.results_type][mod.model_name].append(data)
576
+ # ensure the list stays sorted
577
+ results[key][mod.results_type][mod.model_name].sort(
578
+ key=lambda res:
579
+ f"{res['index_of_arg']}:{res['index_within_arg']}"
580
+ )
581
+
582
+
583
+ # TODO(future PR): align on naming
584
+ # this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
585
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in `model_a` and `model_b`, and extract the logged
    information.

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    return rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
615
+
616
+
617
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Match subgraphs between `gm_a` and `gm_b` and build a single combined
    model in which a shadows b, with loggers attached.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    return create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
637
+
638
+
639
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    # propagate scope info so logger FQNs can be filled in later
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)
683
+
684
+
685
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in a shadow model, and extract the logged
    information.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    # defaultdict avoids key-existence checks while collecting
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    # return a plain dict so callers do not see defaultdict behavior
    return dict(results)
712
+
713
+
714
def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    Args:
        results: the result data structure from `extract_logger_info` or
          `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors
        comparison_name: string name under which comparisons are stored
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                idx_pair_2 = (result_2['index_of_arg'], result_2['index_within_arg'])
                # locate the model-1 entry with matching argument indices
                result_1 = next(
                    (
                        cand for cand in results_1
                        if (cand['index_of_arg'], cand['index_within_arg']) == idx_pair_2
                    ),
                    None,
                )
                assert result_1 is not None

                result_2[comparison_name] = [
                    comparison_fn(value_1, value_2)
                    for value_1, value_2 in zip(result_1['values'], result_2['values'])
                ]
767
+
768
def prepare_n_shadows_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_multi_mapping: QConfigMultiMapping,
    backend_config: BackendConfig,
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
    custom_tracer: Any = None,
) -> GraphModule:
    """
    Given a model with a graph with M ops such as


      args_kwargs_m -> op_m -> output_m


    And a set of N qconfigs for each op, creates a new model, with
    each of the subgraph of `op_m` transformed into

    .. code::

           |---------> op_m_n -> log_m_n
           |        /
      args_kwargs_m ---------> op_m -> log_m_0

    Where op_m_n is op_m wrapped in a submodule and transformed with
    qconfig_n, and its inner graph looks like

    .. code::

      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
                  /
      kwargs_m ---

    This is useful for testing different quantization of multiple layers in
    a single pass through the model.

    High level TODOs for future PRs:
    * figure out a better way to name the output structure
    * return a results data structure instead of printing it out
    * add examples to docblocks
    """

    tracer = quantize_fx.QuantizationTracer([], []) \
        if custom_tracer is None else custom_tracer
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph, modules, patterns, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)
    subgraphs_dedup: Dict[str, List[Node]] = \
        _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    # TODO(future PR): deduplicate repeating entries
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = [
        _generate_node_name_to_qconfig(
            mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
        for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list
    ]

    # For each region in the model, do the following:
    #   For each qconfig for that region, do the following:
    #     1. create a copy of the region wrapped in a module
    #     2. pass original args, original kwargs, and expected output to module
    #     3. add an output comparison logger and hook it up to compare
    #        actual output to expected output
    #     4. run `prepare_fx` on the module
    for subgraph_idx, (match_name, nodes_in_this_subgraph) in \
            enumerate(subgraphs_dedup.items()):
        create_n_transformed_and_logged_copies_of_subgraph(
            mt, subgraph_idx, match_name, nodes_in_this_subgraph,
            qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
            custom_prepare_fn, custom_prepare_kwargs  # type: ignore[arg-type]
        )

    return mt
863
+
864
+ # TODO(future PR): we should rethink the names of all the PNP APIs
865
def _prepare_n_shadows_add_loggers_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> torch.nn.Module:
    r"""
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    This creates a model which provides logging for the following
    problem: if we quantize `model` with `qconfig_mapping` and feed
    the same input through both models, log the comparisons of
    corresponding intermediate layers.

    The problem is solved with a single model. Specifically, we
    partition `model` into N subgraphs, create a copy of each relevant
    subgraph, wrap it in a module, apply the quantization API to that
    module, and hook up loggers to measure the comparisons.

    Example starting graph:

      x0 -> op0 -> x1 -> op1 -> x2

    Example config: quantize op0 to int8, do nothing to op1.
    The following graph will be created:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog

    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
    and clog is a comparison logger.
    """

    tracer = quantize_fx.QuantizationTracer([], [])
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph, modules, patterns, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)
    subgraphs_dedup: Dict[str, List[Node]] = \
        _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)

    # Now, mutate the graph to be the add_loggers graph with propagation
    # error.
    create_add_loggers_graph(
        mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)

    return mt
939
+
940
# TODO(future PR): we should rethink the names of all the PNP APIs
def _n_shadows_compare_weights(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> NSResultsType:
    """
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    Prepares `model` with a single-entry qconfig multi-mapping, runs the
    example inputs through it to populate weight observers, converts it,
    and returns the extracted weight comparison results.
    """
    multi_mapping = \
        QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping])
    prepared = prepare_n_shadows_model(
        model, example_inputs, multi_mapping, backend_config)
    # passing inputs through the model is necessary to populate
    # observers which observe weights with real values
    prepared(*example_inputs)
    converted = convert_n_shadows_model(prepared)
    return extract_weight_comparison(converted)
962
+
963
# TODO(future PR): consider aligning API signature with other similar quantization
# functions (enable_fake_quant, etc)
def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
    """
    Sets the `enabled` setting on a `model`'s loggers
    """
    for module in model.modules():
        if isinstance(module, OutputLogger):
            module.enabled = enabled
972
+
973
# TODO(future PR): consider aligning API signature with other similar quantization
# functions (enable_fake_quant, etc)
def loggers_set_save_activations(
    model: torch.nn.Module,
    save_activations: bool,
) -> None:
    """
    Sets the `save_activations` setting on a `model`'s loggers
    """
    for module in model.modules():
        if isinstance(module, OutputLogger):
            module.save_activations = save_activations
985
+
986
def convert_n_shadows_model(
    model: GraphModule,
    custom_convert_fn: Optional[Callable] = None,
    custom_convert_kwargs: Optional[Dict[str, Any]] = None
) -> GraphModule:
    """
    Given a model from `prepare_n_shadows_model`, runs `convert_fx`
    on each shadow submodule.

    If `custom_convert_fn` is provided, it is called instead of
    `convert_fx`, with `custom_convert_kwargs` (treated as `{}` when None).
    Returns the same `model` object, mutated in place.
    """
    for node in model.graph.nodes:
        # TODO(future PR): consider matching in a safer way than
        # node name string match
        if not node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
            continue
        shadow_mod = getattr(model, node.name)
        if custom_convert_fn is not None:
            if custom_convert_kwargs is None:
                custom_convert_kwargs = {}
            converted_mod = custom_convert_fn(shadow_mod, **custom_convert_kwargs)
        else:
            converted_mod = torch.ao.quantization.quantize_fx.convert_fx(
                shadow_mod)
        setattr(model, node.name, converted_mod)

    return model
1010
+
1011
def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
    """
    Extracts logger results from `model`.
    """
    extracted: NSResultsType = {}
    _extract_logger_info_one_model(model, extracted, OutputLogger)
    return extracted
1018
+
1019
def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
    """
    Prints a summary of extracted `results`.
    """
    grouped = group_results_by_subgraph(results)
    comparison = create_results_comparison(grouped)
    print_n_shadows_summary(comparison)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ toq = torch.ops.quantized
7
+
8
+ import torch.ao.nn.quantized as nnq
9
+ import torch.ao.nn.quantized.dynamic as nnqd
10
+ import torch.ao.nn.intrinsic.quantized as nniq
11
+ import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
12
+ import torch.ao.nn.intrinsic.qat as nniqat
13
+ import torch.ao.nn.intrinsic as nni
14
+ import torch.ao.nn.qat as nnqat
15
+ import torch.ao.nn.qat.dynamic as nnqatd
16
+ from torch.ao.quantization.backend_config import get_native_backend_config
17
+ import torch.ao.quantization.fx._lower_to_native_backend as \
18
+ _lower_to_native_backend
19
+ import torch.ao.quantization.quantization_mappings as quantization_mappings
20
+
21
+ from .ns_types import NSNodeTargetType
22
+
23
+ from typing import Callable, Dict, List, Optional, Set, Tuple
24
+
25
+
26
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """
    Returns a mapping from arbitrary base names (stringified counters
    "0", "1", ...) to sets of related ops. Ops in the same set represent
    the same logical computation across models (for example a float
    module, its QAT/fused variants, and its quantized counterparts), as
    seeded by the hardcoded table below and extended with connections
    derived from the native backend config and the default lowering maps.
    """
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            'relu',
            'relu_',
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            'sigmoid',
            'sigmoid_',
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # functional transposed conv
        {
            F.conv_transpose1d,
        },
        {
            F.conv_transpose2d,
        },
        {
            F.conv_transpose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        },
        # F.hardsigmoid
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
        # pixel shuffle
        {
            nn.PixelShuffle,
        },
        {
            F.pixel_shuffle,
        },
        # pixel unshuffle
        {
            nn.PixelUnshuffle,
        },
        {
            F.pixel_unshuffle,
        },
        # narrow
        {
            torch.narrow,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():

        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from default lowering path
    #

    for source, (target1, target2) in \
            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps, ideally in the future this could be removed
    # after the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))


    # add the new connections from backend_config
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break

    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}

    counter = 0
    for set_of_related_ops in sets_of_related_ops:
        base_name = str(counter)
        counter += 1
        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops

    return base_name_to_sets_of_related_ops
456
+
457
+
458
def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """
    Returns the base name whose related-op set contains `op`,
    or None if `op` is not present in any set.
    """
    return next(
        (
            base_name
            for base_name, related_ops in base_name_to_sets_of_related_ops.items()
            if op in related_ops
        ),
        None,
    )
466
+
467
+
468
def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """
    Adds `op` to the set containing `related_op`, mutating the mapping in
    place. If `related_op` is None, starts a brand new set containing only
    `op`, keyed by the first unused stringified counter. Raises
    AssertionError if `related_op` is given but not found in any set.
    """
    if related_op is None:
        # find the first unused integer-string key and start a new set
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}
        return

    for related_set in base_name_to_sets_of_related_ops.values():
        if related_op in related_set:
            related_set.add(op)
            return
    # if we got here, related_op was not found
    raise AssertionError(f"{related_op} was not found")
485
+
486
+
487
# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """
    Returns a mapping from category name (e.g. 'funs_io_type_fp32') to the
    set of node targets (functions, module classes, or method-name strings)
    whose input/output dtype behavior matches that category: fp32-only,
    fp16-only, int8-only, or "either fp32 or int8" passthrough.
    """
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        F.pixel_shuffle,
        F.pixel_unshuffle,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.narrow,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Dropout,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.PixelShuffle,
        nn.PixelUnshuffle,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }
719
+
720
+
721
+ def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
722
+
723
+ FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
724
+ torch.quantize_per_tensor,
725
+ operator.getitem,
726
+ }
727
+
728
+ MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
729
+ nn.Identity,
730
+ }
731
+
732
+ METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
733
+ 'to',
734
+ 'dequantize',
735
+ 'reshape',
736
+ 'view',
737
+ 'unsqueeze_',
738
+ 'unsqueeze',
739
+ 'transpose',
740
+ 'squeeze_',
741
+ 'squeeze',
742
+ 'size',
743
+ 'shape',
744
+ 'resize_',
745
+ 'repeat_interleave',
746
+ 'repeat',
747
+ 'permute',
748
+ 'numel',
749
+ 'mean',
750
+ 'detach_',
751
+ 'detach',
752
+ 'contiguous',
753
+ 'clamp',
754
+ 'chunk',
755
+ }
756
+
757
+ return {
758
+ 'funs_unmatchable': FUNS_UNMATCHABLE,
759
+ 'mods_unmatchable': MODS_UNMATCHABLE,
760
+ 'meths_unmatchable': METHS_UNMATCHABLE,
761
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py ADDED
@@ -0,0 +1,1311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.fx
3
+ from torch.fx import (
4
+ Node,
5
+ GraphModule,
6
+ Graph,
7
+ )
8
+
9
+ from torch.ao.ns.fx.utils import (
10
+ # TODO(future PR): make this work correctly for methods
11
+ get_target_type_str,
12
+ get_normalized_nth_input,
13
+ )
14
+ from torch.ao.ns.fx.ns_types import (
15
+ NSSingleResultValuesType,
16
+ NSResultsType,
17
+ )
18
+ from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
19
+ from torch.ao.quantization import QConfigMapping
20
+ from torch.ao.quantization.qconfig import QConfigAny
21
+ from torch.ao.quantization.utils import getattr_from_fqn
22
+ from torch.ao.quantization.fx.match_utils import _MatchResult
23
+ from torch.utils._pytree import tree_map
24
+
25
+ import collections
26
+ import copy
27
+ from typing import List, Dict, Set, Tuple, Callable, Any, Optional
28
+ import operator
29
+
30
+ SHADOW_NODE_NAME_PREFIX = 'shadow'
31
+ SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper'
32
+
33
+ # TODO(future PR): reuse existing mapping instead of creating a new one
34
+ BINARY_FUNCTIONS = {
35
+ torch.add,
36
+ torch.Tensor.add,
37
+ operator.add,
38
+ torch.mul,
39
+ torch.Tensor.mul,
40
+ operator.mul,
41
+ }
42
+
43
+ def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
44
+ return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
45
+
46
+ def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
47
+ return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
48
+
49
+
50
class OutputProp:
    """
    Output propagation (modeled from shape propagation).

    Given a GraphModule and an example input, saves the output flowing
    through each node on `node.traced_result`.

    Code based on the example from
    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
    """
    def __init__(self, mod):
        # mod: the torch.fx.GraphModule whose graph will be interpreted
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        # Interprets `self.graph` node by node, with `args` consumed in
        # order as the placeholder values. Each tensor-valued intermediate
        # result is attached to its node as `node.traced_result`.
        # Always returns None.
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            # map Node references inside an arg tree to their computed values
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            # resolve a dotted attribute path (e.g. "sub.weight") on self.mod
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # first positional arg is the method receiver
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # NOTE(review): 'output' nodes match none of the branches above,
            # so `result` still holds the previous node's value here —
            # preexisting behavior, acknowledged by the ignore below.
            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
                node.traced_result = result

            env[node.name] = result

        return None
101
+
102
def _get_dedup_subgraphs(
    matches: Dict[str, _MatchResult]
) -> Dict[str, List[Node]]:
    """
    Converts the per-node `matches` dict into a dict mapping a match name
    to a chronological list of the nodes in that subgraph, emitting each
    subgraph only once even when several nodes share the same match.
    """
    # the original matches variable is unique by node, make it unique by subgraph
    # instead
    seen_nodes = set()
    subgraphs_dedup = {}

    # Dict items are not reversible until Python 3.8, so we hack it
    # to be compatible with previous Python versions
    # TODO(future PR): try reversed(list(matches.items()))
    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
    for name, cur_match in matches.items():
        matches_items_reversed.insert(0, (name, cur_match))

    # Note: the order is important. `matches` currently provides the matches
    # in reverse order. We would like to process the matches in non-reverse
    # order, so that we can create an intuitive naming scheme, such as
    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
        was_seen = False
        for node_or_tuple in cur_match[1]:

            # Cur_match[1] has an unusual type. It says that it's a `List[Node]`,
            # but it is really not. Furthermore, the contents of this field
            # can change from match results of multiple nodes of the same pattern
            #
            # For example, for conv -> bn -> relu, we see
            # match_results = {
            #   'conv': (relu, [(bn, conv), relu], ...),
            #   'bn': (relu, [(bn, conv), relu], ...),
            #   'relu': (relu, [(bn, conv), relu], ...),
            # }
            #
            # Ideally we should clean up the `find_matches` function to make
            # this more intuitive. For the purposes of this prototype, we hack
            # around it.

            if isinstance(node_or_tuple, Node):
                if node_or_tuple in seen_nodes:
                    was_seen = True
                seen_nodes.add(node_or_tuple)

            else:
                assert isinstance(node_or_tuple, tuple)
                for node in node_or_tuple:
                    assert isinstance(node, Node)
                    if node in seen_nodes:
                        was_seen = True
                    seen_nodes.add(node)

        if was_seen:
            # this subgraph was already emitted under an earlier match name
            continue

        # Start with the unusual type, convert it to [op_0, ..., op_n]
        list_of_nodes = []

        if len(cur_match[1]) == 1:
            list_of_nodes = cur_match[1]
        else:
            assert len(cur_match[1]) == 2
            # either (a, b), or ((a, b), c) or (c, (a, b))
            # cannot make any assumptions on order, not clear what the
            # _find_matches function is doing to populate this
            # TODO(future PR): make this code less confusing, see discussion
            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
                # orders three nodes of a linear chain by following each
                # node's producer (args[0]) and first consumer (users)
                nodes = [node_a, node_b, node_c]
                first_node = None
                mid_node = None
                last_node = None
                for n in nodes:
                    prev_n = n.args[0]
                    next_n = next(iter(n.users))
                    if prev_n not in nodes:
                        first_node = n
                    elif next_n not in nodes:
                        last_node = n
                    else:
                        mid_node = n
                assert first_node is not None and mid_node is not None and \
                    last_node is not None
                assert mid_node.args[0] is first_node
                assert last_node.args[0] is mid_node
                return [last_node, mid_node, first_node]

            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
                # (a, b)
                list_of_nodes = cur_match[1]
            elif isinstance(cur_match[1][0], tuple):
                # ((a, b), c)
                node_a, node_b = cur_match[1][0]
                node_c = cur_match[1][1]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)
            elif isinstance(cur_match[1][1], tuple):
                # (a, (b, c))
                node_a, node_b = cur_match[1][1]
                node_c = cur_match[1][0]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)

        # [node_n, ..., node_0], note that the order is reversed
        # to make it chronological for simple subgraphs
        list_of_nodes.reverse()
        subgraphs_dedup[name] = list_of_nodes

    return subgraphs_dedup
209
+
210
def _get_logger_for_subgraph(
    model: GraphModule,
    first_node: Node,
    last_node: Node,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    qconfig_str: str,
    logger_cls: Callable,
    fqn: Optional[str],
) -> torch.nn.Module:
    """
    Given a model and a linear subgraph starting from `first_node` and
    ending with `last_node`, creates a logger for the end of this
    subgraph.
    """
    logger_mod = logger_cls(
        first_node.name,  # ref_node_name
        last_node.name,  # prev_node_name
        f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}',  # model_name
        'model',  # ref_name
        get_target_type_str(last_node, model),  # prev_node_target_type
        get_target_type_str(first_node, model),  # ref_node_target_type
        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
        0,  # index_within_arg
        0,  # index_of_arg
        fqn if fqn is not None else '',  # fqn
        qconfig_str,
    )
    # Usually we expect the user to add loggers, then calibrate, then convert,
    # and then populate loggers. This is why the loggers start disabled.
    # TODO(future PR): reconsider the design to make this more intuitive.
    logger_mod.enabled = False
    return logger_mod
245
+
246
def create_submodule_from_subgraph(
    model: torch.nn.Module,
    first_node: Node,
    last_node: Node,
) -> GraphModule:
    """
    Input: a model, and a linear subgraph within the model from first_node to
    last_node.

    Output: a new submodule containing a copy of the subgraph, with the inputs
    to the first node becoming the inputs to the submodule, and all other
    nodes in the subgraph being copied.

    Example inputs:

    `model`: a module with graph

      x0 -> op1 -> x1 -> op2 -> x2
             |
            arg1

    `first_node`: op1
    `last_node`: op2

    Example output: a new module with graph

      input1 -> op1_copy -> x1 -> op2_copy -> output1
                   |
                  arg1
    """

    #
    # create a blank GraphModule with an empty graph
    #

    class M(torch.nn.Module):
        def forward(self, x):
            pass

    m = M()
    gm = torch.fx.symbolic_trace(m)
    g = gm.graph
    # erase the placeholder/output from the traced stub so we start empty
    for node in reversed(gm.graph.nodes):
        g.erase_node(node)

    #
    # modify the graph to have a copy of our subgraph
    #

    cur_node_orig = first_node
    cur_name_idx = 0

    # guard against infinite loops on malformed/cyclic graphs
    iteration_limit = 100
    cur_iteration = 0

    while True:
        if cur_node_orig is first_node:
            # we are at the first node, we need to set up graph inputs
            # TODO(future): some graphs could have placeholders which are unrelated
            # to the first node, need to handle this
            cur_args_copy = []
            cur_kwargs_copy = {}
            seen_names: Set[str] = set()
            old_name_to_new_node: Dict[str, Node] = {}

            def _add_placeholder(
                g: Graph, node: Node, seen_names, old_name_to_new_node
            ):
                # note: for graphs starting with patterns such as `y = x + x`, we
                # need to ensure we do not add multiple placeholders with the
                # same name
                counter = 0
                while node.name + '_' + str(counter) in seen_names:
                    counter += 1
                cur_name = node.name + '_' + str(counter)
                seen_names.add(cur_name)
                placeholder = g.placeholder(cur_name)
                old_name_to_new_node[node.name] = placeholder
                return placeholder

            for arg in cur_node_orig.args:
                if isinstance(arg, Node):
                    p = _add_placeholder(
                        g, arg, seen_names, old_name_to_new_node)
                    cur_args_copy.append(p)
                elif isinstance(arg, (list, tuple)):
                    new_arg = []
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_arg.append(_add_placeholder(
                                g, inner_arg, seen_names, old_name_to_new_node))
                        else:
                            new_arg.append(inner_arg)
                    cur_args_copy.append(new_arg)
                else:
                    cur_args_copy.append(arg)

            # TODO(future PR): handle non-normalized kwargs
            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
                if isinstance(kwarg, Node):
                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
                        g, kwarg, seen_names, old_name_to_new_node)
                elif isinstance(kwarg, (list, tuple)):
                    new_kwarg = []
                    for inner_kwarg in kwarg:
                        p = _add_placeholder(
                            g, inner_kwarg, seen_names, old_name_to_new_node)
                        new_kwarg.append(p)
                    cur_kwargs_copy[kwarg_name] = new_kwarg
                else:
                    cur_kwargs_copy[kwarg_name] = kwarg

            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
        else:
            # we are not at first node, first arg is from the previous node,
            # and all other args are copied

            # the current implementation is simplistic and cannot handle
            # ops with two or more arguments which need to be passed from
            # the previous op, so we assert them out
            assert cur_node_orig.target not in BINARY_FUNCTIONS

            # at this point in the code, cur_node_copy is pointing to the copy
            # of the previous node
            # TODO(future PR): this is not handling complicated graphs correctly, need to
            # look at actual relationships instead of assuming sequential graph
            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
            # for any fusion pattern which has them for a node that is not the
            # first node.
            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined] # noqa: F821

            if len(cur_node_orig.args) > 1:
                for arg in cur_node_orig.args[1:]:
                    if isinstance(arg, torch.nn.Parameter):
                        new_arg = arg.clone().detach()  # type: ignore[assignment]
                        mod_name = f"mod_{cur_name_idx}"
                        cur_name_idx += 1
                        setattr(gm, mod_name, new_arg)
                        # bugfix: `placeholder` is a method of `torch.fx.Graph`,
                        # not of `GraphModule`; the original `gm.placeholder(...)`
                        # would raise AttributeError whenever this branch ran
                        new_arg_placeholder = g.placeholder(mod_name)
                        cur_args_copy.append(new_arg_placeholder)
                    elif isinstance(arg, (float, int, torch.dtype)):
                        cur_args_copy.append(arg)
                    else:
                        raise AssertionError(f'arg of type {type(arg)} not handled yet')
            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]

        # copy the node
        if cur_node_orig.op == 'call_module':
            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
            orig_mod_copy = copy.deepcopy(orig_mod)
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        elif cur_node_orig.op == 'call_function':
            cur_node_copy = g.call_function(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        elif cur_node_orig.op == 'call_method':
            cur_node_copy = g.call_method(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]

        else:
            raise AssertionError(f'{cur_node_orig.op} not supported yet')

        if cur_node_orig is last_node:
            break

        # go to next node; only linear chains are supported
        assert len(cur_node_orig.users.keys()) == 1, \
            f'{cur_node_orig} has more than 1 users, not supported yet'
        cur_node_orig = next(iter(cur_node_orig.users.keys()))

        cur_iteration += 1
        if cur_iteration > iteration_limit:
            raise AssertionError('iteration limit exceeded')

    # set up outputs
    g.output(cur_node_copy)

    gm.recompile()
    return gm
434
+
435
def create_one_transformed_and_logged_copy_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    first_node: Node,
    last_node: Node,
    fqn: Optional[str],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    example_inputs: Any,
    last_added_shadow_node_list: List[Optional[Node]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
    subgraph candidate copy and instruments it with loggers.

    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
    add a logger to the end.

    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
    prepare it with `prepare_fx`.
    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger

    if subgraph_candidate_idx == 0:
        # candidate 0 is the floating point (original) version of the subgraph;
        # keep the subgraph as is and just observe its output
        logger_mod = _get_logger_for_subgraph(
            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
            '', OutputLogger, fqn)

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod)
        with mt.graph.inserting_after(last_node):
            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
        last_added_shadow_node_list[0] = new_node

    else:
        # candidate idx > 0 means we have a qconfig to try: copy the subgraph,
        # feed it the right inputs, and observe its output vs the original

        # subtract one because the first candidate is the floating point
        # version of the subgraph
        node_name_to_qconfig = \
            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]

        # if no quantization is requested, skip
        # TODO(future PR): deduplicate equivalent qconfigs that come from
        # different qconfig mapping objects
        if qconfig is None:
            return

        qconfig_mapping = QConfigMapping().set_global(qconfig)

        # create a copy of the submodule, wrapped in a separate module
        wrapped_copy = create_submodule_from_subgraph(
            mt, first_node, last_node)

        # run prepare_fx (or the user-supplied prepare fn) on the wrapper
        if custom_prepare_fn is None:
            wrapped_copy = torch.ao.quantization.quantize_fx.prepare_fx(
                wrapped_copy, qconfig_mapping, example_inputs=example_inputs)
        else:
            if custom_prepare_kwargs is None:
                custom_prepare_kwargs = {}
            for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]:
                assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs"
            prepare_kwargs: Dict[str, Any] = {
                "example_inputs": example_inputs,
                "qconfig_mapping": qconfig_mapping
            }
            prepare_kwargs.update(custom_prepare_kwargs)
            wrapped_copy = custom_prepare_fn(
                wrapped_copy,
                **prepare_kwargs)

        # attach the wrapper to the model
        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, wrapped_copy)

        # add a call to the wrapper module from the parent graph
        insert_after_node = last_added_shadow_node_list[0]
        with mt.graph.inserting_after(insert_after_node):
            # TODO(future PR): handle fusion patterns where non-first nodes
            # need inputs

            # pass in all node args and kwargs
            new_args = []
            for arg in first_node.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node):
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_args.append(inner_arg)

            new_kwargs = {}
            for name, old_kwarg in first_node.kwargs.items():
                if isinstance(old_kwarg, Node):
                    new_kwargs[name] = old_kwarg
                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
                    # TODO(future PR): clarify why we are adding kwargs to args
                    new_args.extend(old_kwarg)

            new_node = mt.graph.call_module(
                attr_name, args=tuple(new_args), kwargs=new_kwargs)

        # add a logger to parent graph to observe the shadow wrapper
        logger_mod = _get_logger_for_subgraph(
            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
            str(qconfig), OutputComparisonLogger, fqn)

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod)
        with mt.graph.inserting_after(new_node):
            logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={})
        last_added_shadow_node_list[0] = logger

    mt.recompile()
568
+
569
def create_n_transformed_and_logged_copies_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    match_name: str,
    nodes_in_this_subgraph: List[Any],
    qconfig_mappings: List[QConfigMapping],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a model `mt` and a subgraph_idx, creates the needed copies
    of the subgraph for all qconfigs, and instruments them with loggers.
    """
    # for now, assume that
    # 1. the first node has one input
    # 2. the last node has one output

    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
    # TODO(future PR): implement this
    if any(
        not isinstance(node, Node)
        for node in nodes_in_this_subgraph
    ):
        return

    first_node = nodes_in_this_subgraph[0]
    last_node = nodes_in_this_subgraph[-1]
    # We used output propagation to populate example values on each
    # node. Use the example values from the previous node as the input
    # to the current node.
    prev_node = get_normalized_nth_input(first_node, mt, 0)
    if isinstance(prev_node, list):
        example_inputs = [x.traced_result for x in prev_node]
    elif isinstance(prev_node, tuple):
        # bugfix: the original `(x.traced_result for x in prev_node)` built a
        # generator, not a tuple, so `example_inputs` would not contain the
        # traced values (and would be exhausted after a single use)
        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
    else:
        # currently some customer models do not have a traced_result in
        # every node, so we have to guard for this case since we cannot
        # quantize without an example input
        # TODO(future PR): add a test case for this once we have an easy
        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
        # for additional context
        if hasattr(prev_node, 'traced_result'):
            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
        else:
            print(
                'unable to get example input for node ' +
                f'{first_node.format_node()}, skipping')
            return

    # If there are no quantization configs for this subgraph, skip adding
    # loggers. This reduces memory usage for models where not all layers are
    # quantized.
    # TODO(future): consider making this configurable
    found_at_least_one_qconfig = False
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):

        if subgraph_candidate_idx == 0:
            # fp32 baseline does not need a qconfig
            continue

        # a. we have N shadows, so len(qconfig_mappings) is N
        # b. we will have the fp32 layer + N shadows, so overall number of
        #    (original_op) + (*shadows) will be N+1
        # c. since `subgraph_candidate_idx` represents (b), we need
        #    to subtract 1 to query from (a)
        node_name_to_qconfig = \
            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]
        if qconfig is not None:
            found_at_least_one_qconfig = True
            break
    if not found_at_least_one_qconfig:
        print('unable to find at least one qconfig for node ' +
              f'{first_node.format_node()}, skipping')
        return

    fqn = _maybe_get_fqn(first_node, mt)

    # We want the results to contain the subgraphs in natural order,
    # and the graph to also contain shadow wrappers and shadow loggers
    # in natural order.
    # If we just iterate in reverse, the graph will be in natural
    # order but the eventual results will be in reverse order.
    # So, we keep track of the last shadow logger we added and
    # always insert after it.
    last_added_shadow_node_list: List[Optional[Node]] = [None]
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):

        create_one_transformed_and_logged_copy_of_subgraph(
            mt, subgraph_idx, subgraph_candidate_idx, first_node,
            last_node, fqn, list_of_node_name_to_qconfig,
            example_inputs, last_added_shadow_node_list, custom_prepare_fn,
            custom_prepare_kwargs)
664
+
665
def create_add_loggers_graph(
    model: GraphModule,
    subgraphs_dedup: Dict[str, List[Node]],
    qconfig_mapping: QConfigMapping,
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    r"""
    Given a model, a model graph partition (currently a set of matched
    subgraphs) and instructions how to transform each subgraph
    (currently quantizing it according to qconfig_mapping), modifies
    the model graph to create an alternate path through the original graph,
    with each of the subgraphs quantized.  This is useful to compare
    propagation error of a transformation such as quantization.

    For example, given layer op0 and op1, there are four cases when handling op1:
    1. op0 and op1 quantized
    2. op0 and op1 unquantized
    3. op0 quantized, op1 unquantized
    4. op0 unquantized, op1 quantized

    Example input, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog    op1_1 -> x2_1 ----> clog

    Example output, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog

    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger

    def _get_subgraph_containing_node(node, subgraphs_dedup):
        # linear scan over the deduped subgraphs; returns the subgraph that
        # contains `node`, or None if it is unmatched
        for subgraph in subgraphs_dedup.values():
            if node in subgraph:
                return subgraph
        return None

    # Step 1: create the shadow branches, going from
    #
    #   x0 -> op0 -> x1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog
    #
    # Later, the outputs of each shadow will be rerouted to calculate
    # propagation error.

    # Note: we cannot iterate over matched subgraphs because some nodes
    # may not be matched. So, we iterate over nodes in the graph, and
    # associate them to matched subgraphs if possible.

    nodes_to_skip = set()
    # for each subgraph, map its first original node to the first and last
    # node of its shadow copy
    orig_first_node_to_shadow_in_node = {}
    orig_first_node_to_shadow_out_node = {}
    # need to record original list because we will mutate the graph as we go
    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
    cur_subgraph_idx = 0
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        insert_submodule_copy = False
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
            qconfig = node_name_to_qconfig[first_node.name]
            if qconfig is not None:
                insert_submodule_copy = True
        else:
            first_node, last_node = n, n

        if insert_submodule_copy:
            match_name = first_node.name
            create_n_transformed_and_logged_copies_of_subgraph(
                model, cur_subgraph_idx, match_name, maybe_subgraph,
                [qconfig_mapping], [node_name_to_qconfig],
                None, None  # type: ignore[arg-type]
            )
            # find the created shadow module and record it so we
            # can find it easily in step 2
            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
            new_shadow_mod = None
            for maybe_shadow_mod in model.graph.nodes:
                if maybe_shadow_mod.op == 'call_module' and \
                        maybe_shadow_mod.target == expected_shadow_target:
                    new_shadow_mod = maybe_shadow_mod
                    break
            assert new_shadow_mod is not None
            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

        else:
            # create a copy of the subgraph by only copying FX nodes
            # but not copying any parameters, to minimize memory usage
            subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \
                else [first_node]

            # add a regular logger after last_node
            qconfig_str = ''
            subgraph_candidate_idx = 0
            fqn = _maybe_get_fqn(first_node, model)
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
                qconfig_str, OutputLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            insertion_point = last_node
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(last_node,), kwargs={})
                insertion_point = logger

            # create a copy of the subgraph
            cur_node_orig = first_node
            cur_node_copy = None
            first_node_copy = None
            while cur_node_orig in subgraph_to_use:
                # TODO(future PR): make this support all possible args/kwargs
                if cur_node_orig is first_node:
                    new_args = cur_node_orig.args
                    new_kwargs = cur_node_orig.kwargs
                else:
                    first_arg_for_copy = cur_node_copy
                    new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]])  # noqa: C409
                    new_kwargs = cur_node_orig.kwargs
                # make a copy of cur_node_orig
                with model.graph.inserting_after(insertion_point):
                    cur_node_copy = model.graph.create_node(
                        cur_node_orig.op,
                        cur_node_orig.target,
                        new_args,
                        new_kwargs,
                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
                    )
                    if first_node_copy is None:
                        first_node_copy = cur_node_copy
                # since now only linear subgraphs are supported, all nodes
                # except the last one must have only one user
                if cur_node_orig != last_node:
                    assert len(cur_node_orig.users.keys()) == 1
                cur_node_orig = next(iter(cur_node_orig.users.keys()))
                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
                insertion_point = cur_node_copy

            # add a comparison logger after last_node's copy
            subgraph_candidate_idx = 1
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
                qconfig_str, OutputComparisonLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(cur_node_copy, last_node), kwargs={})

            # save the final node so we can use it in step 2
            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy

        cur_subgraph_idx += 1

    model.recompile()

    # Step 2: reroute the shadow inputs, going from
    #
    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
    #    \       \                     \
    #      -> op0_1 -> x1_1 -> clog     -> op1_1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
    #    \       \
    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
    #
    # sample values of key internal variables for the example above:
    #
    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
    #
    # note: for subgraphs with more than one node, in_node will be different
    # compared to out_node

    nodes_to_skip = set()
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
        else:
            first_node, last_node = n, n

        def maybe_remap_node_to_shadow(node):
            """
            If unshadowed `node` has a shadow version, return that. If not,
            return `node`.
            """
            if not isinstance(node, Node):
                # handle scalars
                return node

            if node.op in ('placeholder', 'get_attr'):
                return node

            # Find the shadowed version of this arg from the previous
            # subgraph. For this, we need to:
            # 1. navigate to the first node of the previous subgraph
            # 2. get the output of the shadow wrapper which has (1) as an input

            # For now, assume the arg is in matched subgraphs. In the
            # future we may have to handle the case where this is not true.
            prev_subgraph = _get_subgraph_containing_node(
                node, subgraphs_dedup)
            if prev_subgraph is None:
                prev_subgraph = [node]
            prev_first_node = prev_subgraph[0]
            prev_shadow_output = \
                orig_first_node_to_shadow_out_node[prev_first_node]
            return prev_shadow_output

        cur_shadow_input = \
            orig_first_node_to_shadow_in_node[first_node]
        assert cur_shadow_input is not None
        cur_shadow_input.args = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.args)
        cur_shadow_input.kwargs = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.kwargs)

    model.recompile()
918
+
919
def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
    """
    Given a shadow wrapper module, returns `(quantize_fn, quantize_fn_args)`
    for the weight of its weighted op, or None if the wrapper does not
    contain a weighted op.
    """
    # For now, assume that the weight is the second input
    # to the shadow module. If that changes, we can fix it later.
    placeholders_seen = 0
    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
        if shadow_n.op != 'placeholder':
            continue

        placeholders_seen += 1
        if placeholders_seen != 2:
            continue

        # the subgraph looks like
        #
        #   _input_scale_1 = self._input_scale_1
        #   _input_zero_point_1 = self._input_zero_point_1
        #   quantize_per_channel = torch.quantize_per_channel(
        #       w2_0, _input_scale_1, _input_zero_point_1,
        #       0, torch.qint8)
        #
        # we have `w2_0`, and are navigating this subgraph
        # to get `_input_scale_1` and `_input_zero_point_1`

        assert len(shadow_n.users) == 1
        quant_node = next(iter(shadow_n.users.keys()))
        new_args: Any = None
        if quant_node.target == torch.quantize_per_channel:
            _weight, scale_node, zp_node, axis, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, axis, dtype)
        else:
            assert quant_node.target == torch.quantize_per_tensor
            _weight, scale_node, zp_node, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, dtype)
        return (quant_node.target, new_args)

    return None
969
+
970
+
971
def extract_weight_comparison(m: GraphModule) -> NSResultsType:
    """Extract fp32-vs-quantized weight comparisons from a shadow model.

    For every allowlisted weighted `call_function` node that has a
    corresponding `shadow_wrapper_*` module, detaches the fp32 weight,
    requantizes it with the parameters recovered from the shadow wrapper,
    computes SQNR between the two, and records one result entry for the
    fp32 weight (`subgraph_n_0`) and one for the quantized weight
    (`subgraph_n_1`).

    Args:
        m: the shadow GraphModule produced by the n-shadows flow.

    Returns:
        an NSResultsType dict under `results['model'][WEIGHT]`.
    """
    # example graph:
    #
    # w1 = self.w1
    # b1 = self.b1
    # linear = torch._C._nn.linear(x, w1, b1)
    # shadow_0_0 = self.shadow_0_0(linear)
    # shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
    # shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
    #
    # algorithm:
    # 1. for each call_function node matching our allowlist:
    # 2. if corresponding shadow wrapper exists, extract the weight pair
    #
    # Note: this is not super robust, but that's ok because this is
    # just for legacy customers who depend on the previous two-model version
    # of this API. TBD if we need to make this robust.
    # Note: modules are not supported, since existing customers only
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = {
        torch.nn.functional.linear,
    }

    results: NSResultsType = {
        'model': {NSSingleResultValuesType.WEIGHT.value: {}}
    }

    for n in m.graph.nodes:  # type: ignore[union-attr]
        if not (n.op == 'call_function' and n.target in weighted_ops):
            continue

        # Check if we have a corresponding shadow wrapper
        # TODO(future PR, if needed): support kwargs
        # TODO(future PR, if needed): support multiple shadow users
        first_arg = n.args[0]
        shadow_wrapper_node = None
        for user in first_arg.users:
            # TODO(before land): fix string match
            if user.op == 'call_module' and \
                    user.target.startswith('shadow_wrapper'):
                shadow_wrapper_node = user
                break

        if shadow_wrapper_node is None:
            continue

        shadow_wrapper = getattr_from_fqn(
            m, shadow_wrapper_node.target)  # type: ignore[arg-type]
        weight_info = _get_weight_info_from_shadow_wrapper(
            shadow_wrapper)
        if weight_info is None:
            continue

        # get weight
        w_node = n.args[1]
        w_obj = getattr_from_fqn(m, w_node.target).detach()

        # get a quantized version of weight
        quant_fn, quant_fn_args_except_first = weight_info
        new_args = (w_obj, *quant_fn_args_except_first)
        w_obj_q = quant_fn(*new_args)

        # add a comparison
        ref_node_name = n.name
        prev_node_name = n.name
        ref_node_type = get_target_type_str(n, m)
        prev_node_type = ref_node_type
        fqn = None
        if hasattr(m, '_node_name_to_scope'):
            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)

        # The fp32 and quantized result entries are identical except for
        # 'values'; build both from one template to keep them in sync.
        def _make_result(values):
            # one NSSingleResultType dict for a single weight tensor
            return {
                'res_type': NSSingleResultValuesType.WEIGHT.value,
                'values': values,
                'prev_node_name': prev_node_name,
                'prev_node_target_type': prev_node_type,
                'ref_node_name': ref_node_name,
                'ref_node_target_type': ref_node_type,
                'index_within_arg': 0,
                'index_of_arg': 0,
                'fqn': fqn,
                'qconfig_str': '',
                'comparisons': [comparison],
                'comparison_fn_name': 'sqnr',
            }

        result_fp32 = _make_result([w_obj])
        result_q = _make_result([w_obj_q])

        # go from subgraph_n_1 to subgraph_n_0
        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_')
        name_fp32 = f"subgraph_{node_idx}_0"
        name_q = f"subgraph_{node_idx}_1"

        results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \
            [result_fp32]
        results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \
            [result_q]

    return results
1085
+
1086
+ # TODO(future PR): redesign this to make it easier to consume outputs
1087
def group_results_by_subgraph(results: NSResultsType) -> Any:
    """
    Regroup flat per-candidate results by subgraph.

    Input (one flat entry per ``subgraph_m_n`` candidate)::

        {
          'model': {
            'node_output': {
              'subgraph_0_0': [{'values': ..., 'ref_node_name': ..., ...}],
              'subgraph_0_1': [{...}],
              ...
            },
          },
        }

    Output (nested: subgraph name -> candidate idx -> summary)::

        {
          'subgraph_0': {
            '0': {'ref_node_name': ..., 'values': ..., 'comparisons': ..., ...},
            '1': {...},
          },
        }
    """
    grouped: Any = collections.defaultdict(dict)

    # the results type to consume: 'node_output' or 'weight'
    results_key = next(iter(results['model'].keys()))

    # fields carried over from the first logged result of each candidate
    fields_to_copy = (
        'ref_node_name',
        'ref_node_target_type',
        'fqn',
        'values',
        'qconfig_str',
        'comparisons',
        'comparison_fn_name',
    )

    for name_with_idx, candidate_results in results['model'][results_key].items():
        # `subgraph_m_n` -> subgraph name `subgraph_m`, candidate idx `n`
        prefix, subgraph_idx, candidate_idx = name_with_idx.split('_')
        first_result = candidate_results[0]
        grouped[f'{prefix}_{subgraph_idx}'][candidate_idx] = {
            field: first_result[field] for field in fields_to_copy
        }

    return dict(grouped)
1171
+
1172
+ # TODO(future PR): redesign this to make it easier to consume outputs
1173
def create_results_comparison(
    results_grouped,
) -> Any:
    """
    Collapse grouped results into per-subgraph comparison summaries.

    For each subgraph, candidate '0' is the baseline: its identifying
    fields (``ref_node_name``, ``ref_node_target_type``, ``fqn``) are
    copied to the top level, and every other candidate contributes a
    ``candidates`` entry with its stacked raw comparisons and their mean.

    Output shape::

        {
          'subgraph_0': {
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'fqn': ...,
            'candidates': {
              '1': {'qconfig_str': ..., 'comparison_fn_name': ...,
                    'cmp_raw': tensor, 'cmp_mean': tensor},
              ...,
            },
          },
        }
    """
    comparison_by_subgraph = {}

    for subgraph_name, per_candidate in results_grouped.items():

        candidates = {}
        for cand_name, cand_result in per_candidate.items():
            # skip comparing baseline to baseline
            if cand_name == '0':
                continue

            # we expect the comparisons to be precalculated from
            # calibration, so we just fetch them here
            stacked = torch.stack(cand_result['comparisons'])

            candidates[cand_name] = {
                'qconfig_str': cand_result['qconfig_str'],
                'comparison_fn_name': cand_result['comparison_fn_name'],
                'cmp_raw': stacked,
                'cmp_mean': torch.mean(stacked),
            }

        baseline = per_candidate['0']
        comparison_by_subgraph[subgraph_name] = {
            'ref_node_name': baseline['ref_node_name'],
            'ref_node_target_type': baseline['ref_node_target_type'],
            'fqn': baseline['fqn'],
            'candidates': candidates,
        }

    return comparison_by_subgraph
1251
+
1252
+ # TODO(future PR): redesign this to make it easier to consume outputs
1253
def print_n_shadows_summary(
    results_comparison,
) -> None:
    """
    Print a tabular summary of per-subgraph candidate comparisons.

    Input::

        {
          'subgraph_0': {
            'ref_node_name': 'linear1',
            'ref_node_target_type': '...',
            'fqn': '...',
            'candidates': {
              '1': {'qconfig_str': ..., 'comparison_fn_name': ...,
                    'cmp_raw': [45.0, 55.0], 'cmp_mean': 50.0},
              ...,
            },
          },
        }

    Prints::

        node_name | node_type | fqn | 0    | 1    | ...
        linear1   | ...       | ... | 45.0 | 50.0 | ...

    Returns None; falls back to an instructional message if the optional
    `tabulate` dependency is missing.
    """

    try:
        from tabulate import tabulate
    except ImportError:
        print("`print_tabular` relies on the library `tabulate`, "
              "which could not be found on this machine. Run `pip "
              "install tabulate` to install the library.")
        return

    results = []
    for subgraph_data in results_comparison.values():
        mean_all_candidates = [
            candidate['cmp_mean']
            for candidate in subgraph_data['candidates'].values()
        ]

        data_row = [
            subgraph_data['ref_node_name'],
            subgraph_data['ref_node_target_type'],
            subgraph_data['fqn'],
            *mean_all_candidates,
        ]
        results.append(data_row)

    # Bug fix: the candidate column count is the number of entries after
    # the three fixed columns (name/type/fqn), not `len(data_row[1])`,
    # which measured the character length of the node-type string.
    max_candidate_idx_len = -1
    for data_row in results:
        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]

    headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers]
    print(tabulate(results, headers=headers))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import enum
from typing import NamedTuple

from torch.fx.graph import Node

from typing import Dict, Any, List, Union, Callable

class NSSingleResultValuesType(str, enum.Enum):
    """The kind of data captured in a single Numeric Suite result."""
    WEIGHT = 'weight'
    NODE_OUTPUT = 'node_output'
    NODE_INPUT = 'node_input'

class NSSubgraph(NamedTuple):
    """A matched subgraph: first node, last node, and the node holding the
    base op used to relate this subgraph to equivalent ops elsewhere."""
    start_node: Node
    end_node: Node
    base_op_node: Node

# A single logged result. See the key-by-key description below.
# TODO(future PR): see if we can use typing_extensions's TypedDict instead
# to properly type the various keys
# {
#   # one of NSSingleResultValuesType
#   'type': 'weight',
#   # the values of type specified above
#   'values': [torch.tensor(...), ...],
#   # name of the node directly before the logger
#   'prev_node_name': 'linear1',
#   # type of the underlying function or module
#   'prev_node_target_type': torch.nn.functional.linear  # or torch.nn.Linear, etc
#   # name of the node responsible for adding this logger
#   # Note: this may differ from prev_node_name if we are logging inputs
#   'ref_node_name': 'linear1',
#   # index of this node within the arg of the input/output node
#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
#   'index_within_arg': 0,
#   # index of this node within the args of the input/output node
#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
#   'index_of_arg': 0,
#   # precomputed comparisons of logger values to reference values
#   'comparisons': [torch.tensor(...), ...]
#   # name of function used for precomputed comparisons
#   'comparison_fn_name': 'sqnr',
#   # string representation of qconfig responsible for creating this logger
#   'qconfig_str': 'QConfig(...)',
# }
NSSingleResultType = Dict[str, Any]

# Full results structure: subgraph name -> results type -> model name ->
# ordered list of single results, e.g.
# {
#   'layer_name_1': {  # subgraph name
#     'node_output': {  # results type (node_output, node_input, weight)
#       'model_name_a':  # model name
#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
#       'model_name_b':
#          [NSSingleResultType, ...],
#     },
#   },
# }
#
NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]

# Defines the underlying target type of a node, for example:
# `F.conv1d` for a `call_function` conv node
# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
NSNodeTargetType = Union[Callable, str]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ toq = torch.ops.quantized
5
+
6
+ from torch.fx import GraphModule
7
+ from torch.fx.graph import Node
8
+
9
+ from torch.ao.quantization.backend_config import get_native_backend_config
10
+ from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
11
+ from torch.ao.quantization.utils import getattr_from_fqn
12
+ from .ns_types import NSNodeTargetType
13
+ from torch.ao.quantization import (
14
+ ObserverBase,
15
+ FakeQuantizeBase,
16
+ )
17
+
18
+ from typing import Dict, Tuple, Set, Callable, Any, Union, List
19
+
20
+
21
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    """Return every ordered pair (a, b) where a and b belong to the same
    set of related ops, including the (x, x) self-pairs."""
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    related_pairs: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for related_ops in base_name_to_sets_of_related_ops.values():
        ops = list(related_ops)
        # every ordered pair within the set, both directions and self-pairs
        for first in ops:
            for second in ops:
                related_pairs.add((first, second))

    return related_pairs
38
+
39
+
40
# One element of a (reversed) fusion pattern.
NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[str, Any],  # call_method name and first argument, example: ("to", torch.float16)
]
# A reversed fusion pattern; only 2-element and 4-element shapes are used.
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
49
+
50
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Set of potential fusions, in reverse order.  The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents the
    second op, etc.
    """
    # NOTE(review): the order of entries in `results` appears to matter to
    # the downstream matcher (first match wins) — confirm before reordering.
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

    default_base_op_idx = 0
    for quant_pattern in all_quant_patterns.keys():
        # TODO: this is a temporary hack to flatten the patterns from quantization so
        # that it works with the ns matcher function, maybe we should use `_is_match`
        # in torch.ao.quantization.fx.match_utils to match the patterns
        if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
                isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
            # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])

        # Only patterns of multiple ops are fusions, ignore
        # patterns which contain a single ops (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]


    # After this point, results contains values such as
    # [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    # For each manually defined pattern, also register the variants with an
    # observer or fake-quant appended (prepended in reversed order).
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results
125
+
126
+
127
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.

    Walks backwards from `end_node` through each node's first positional
    arg, checking one element of `reversed_fusion` per step.  Any node
    already in `seen_nodes` fails the match (a node can belong to at most
    one matched pattern).
    """
    cur_node = end_node
    for fusion_idx in range(len(reversed_fusion)):
        # each node can only belong to one matched pattern
        if cur_node in seen_nodes:
            return False

        cur_fusion_el = reversed_fusion[fusion_idx]

        if cur_node.op == 'call_function':
            # a function element is anything that is neither a method name
            # (str) nor a module class (type)
            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
                (not isinstance(cur_fusion_el, type))
            if fusion_el_is_fun:
                if cur_node.target != cur_fusion_el:
                    return False
                # step to the previous node in the chain (first positional arg)
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_module':
            fusion_el_is_mod = isinstance(cur_fusion_el, type)
            if fusion_el_is_mod:
                assert isinstance(cur_node.target, str)
                target_mod = getattr_from_fqn(gm, cur_node.target)
                # NOTE(review): this isinstance check is redundant at runtime
                # (fusion_el_is_mod established it above); presumably kept for
                # type-checker narrowing — confirm before removing.
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                # step to the previous node in the chain (first positional arg)
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_method':
            fusion_el_is_meth_with_second_arg = \
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    # element is (method_name, expected_first_arg),
                    # e.g. ("to", torch.float16)
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False

                # step to the previous node in the chain (first positional arg)
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            return False

    return True
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import Any, Callable, Dict, List, Union
5
+
6
+ import torch
7
+ from torch.ao.quantization import QConfigMapping
8
+ from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
9
+ from torch.ao.quantization.qconfig import QConfigAny
10
+
11
__all__ = ["QConfigMultiMapping"]

# Maps each QConfigMapping attribute name (a "qconfig style") to the name of
# the QConfigMapping setter method used to insert a qconfig of that style.
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
    "global_qconfig": "set_global",
    "object_type_qconfigs": "set_object_type",
    "module_name_regex_qconfigs": "set_module_name_regex",
    "module_name_qconfigs": "set_module_name",
    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
}
20
+
21
+ def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
22
+ to_remove = []
23
+ for index, cur_qconfig in enumerate(qconfig_list):
24
+ if cur_qconfig is None:
25
+ to_remove.append(index)
26
+ break
27
+ for checked_qconfig in qconfig_list[:index]:
28
+ if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
29
+ to_remove.append(index)
30
+ break
31
+ for index in to_remove[::-1]:
32
+ qconfig_list.pop(index)
33
+
34
class QConfigMultiMapping:
    """
    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
    so that multiple QConfigs can be specified for each QConfig matching style.

    The user can specify QConfigs using the following methods (in increasing match priority):

    ``set_global`` : sets the global (default) QConfigs

    ``set_object_type`` : sets the QConfigs for a given module type, function, or method name

    ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string

    ``set_module_name`` : sets the QConfigs for modules matching the given module name

    ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
    of the given module name, object type, and the index at which the module appears

    Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
    single QConfig.

    Example usage::

        qconfig_mapping = QConfigMultiMapping()
            .set_global([qconfig1, qconfig2])
            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
            .set_object_type(torch.nn.ReLU, [qconfig1])
            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
            .set_module_name("module1", [None])
            .set_module_name("module2", [qconfig2])
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])

    """

    def __init__(self):
        # initialize this with 1 QConfigMapping to avoid corner cases
        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

    def _handle_list_size_mismatch(
        self, qconfig_list: List[QConfigAny], style: str
    ) -> None:
        """Pad either `self.qconfig_mappings_list` (with new None-filled
        QConfigMappings) or `qconfig_list` (with None entries) so that both
        have the same length before insertion."""
        # this method handles cases where the size of qconfig_list does not match
        # the size of qconfig_mappings_list.
        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
        # qconfig C as an object_type_qconfig for conv ops. If we internally store
        # 1 QConfigMapping with A and C and another with just B, then the
        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.

        # we avoid this by maintaining the invariant that if any QConfigMapping
        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
        # have either a qconfig or None for that same style+key. In the above
        # example, a None qconfig would prevent the unwanted match in the
        # second QConfigMapping

        if len(qconfig_list) > len(self.qconfig_mappings_list):
            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings

            # Add new QConfigMappings (initialized so we maintain the `invariant`)

            new_qconfig_mapping = QConfigMapping()
            # searches other QConfigMappings for qconfig style+keys
            # that need to be inserted as `None` into the new QConfigMapping
            for qconfig_mapping in self.qconfig_mappings_list:

                # global_qconfig has None by default
                for check_style in _QCONFIG_STYLE_ORDER[1:]:
                    qconfigs_dict = getattr(qconfig_mapping, check_style)
                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
                    for key in qconfigs_dict:
                        target_qconfigs_dict[key] = None
                # NOTE(review): only the first QConfigMapping is examined —
                # presumably the invariant above guarantees all mappings share
                # the same style+keys, so checking one suffices; confirm.
                break

            # insert copies of this new QConfigMapping until all entires
            # in qconfig_list can fit among the QConfigMappings
            while len(qconfig_list) > len(self.qconfig_mappings_list):
                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
        else:
            # Case: we have fewer qconfigs in qconfig_list than QConfigMappings

            # pad qconfig_list with `None` until length is same
            while len(qconfig_list) < len(self.qconfig_mappings_list):
                qconfig_list.append(None)

    # this function applies the insertion method across each QConfigMapping
    def _insert_qconfig_list(
        self,
        style: str,
        args: List[Union[str, int, Callable]],
        qconfig_list: List[QConfigAny],
    ) -> None:
        """Insert the i-th qconfig of `qconfig_list` into the i-th
        QConfigMapping via the setter method associated with `style`."""

        # we remove duplicates and None to make the ordering of qconfigs
        # deterministic upon insertion.
        _remove_duplicates_and_none(qconfig_list)

        self._handle_list_size_mismatch(qconfig_list, style)
        method_name = _QCONFIG_STYLE_TO_METHOD[style]
        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
            # uses QConfigMapping set method to insert qconfig
            set_method = getattr(qconfig_mapping, method_name)
            set_method(*args, qconfig)

    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
        """
        Set global QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
        """
        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set object type QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
        """
        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name_regex QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
        """
        self._insert_qconfig_list(
            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
        )
        return self

    def set_module_name(
        self, module_name: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
        """
        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
        return self

    def set_module_name_object_type_order(
        self,
        module_name: str,
        object_type: Callable,
        index: int,
        qconfig_list: List[QConfigAny],
    ) -> QConfigMultiMapping:
        """
        Set module_name_object_type_order QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
        """
        self._insert_qconfig_list(
            "module_name_object_type_order_qconfigs",
            [module_name, object_type, index],
            qconfig_list,
        )
        return self

    def __repr__(self):
        return (
            self.__class__.__name__ +
            " [" +
            "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) +
            "\n]"
        )

    @classmethod
    def from_list_qconfig_mapping(
        cls, qconfig_mapping_list: List[QConfigMapping]
    ) -> QConfigMultiMapping:
        """
        Creates a QConfigMultiMapping from a list of QConfigMappings
        """
        new_qconfig_multi_mapping = cls()

        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
            qconfig_mapping_list
        )

        # we need to avoid the issue described in _handle_list_size_mismatch,
        # so we reinsert all the qconfigs using the QConfigMultiMapping
        # set methods

        # go through all qconfig styles
        # note: global can be ignored since it is None by default
        for style in _QCONFIG_STYLE_ORDER[1:]:

            # gather all key+qconfigs for current style
            # into qconfig_dict_list
            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
            for qconfig_mapping in qconfig_mapping_list:
                qconfig_dict = getattr(qconfig_mapping, style)
                for key, qconfig in qconfig_dict.items():
                    if key not in qconfig_dict_list:
                        qconfig_dict_list[key] = []
                    qconfig_dict_list[key].append(qconfig)

            # reinsert all gathered key+qconfigs
            set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
            set_method = getattr(new_qconfig_multi_mapping, set_method_name)
            for key, qconfig_list in qconfig_dict_list.items():
                # module_name_object_type_order keys are tuples and must be
                # splatted into the setter's positional args
                if isinstance(key, tuple):
                    set_method(*key, qconfig_list)
                else:
                    set_method(key, qconfig_list)

        return new_qconfig_multi_mapping
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc ADDED
Binary file (7.83 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc ADDED
Binary file (75 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc ADDED
Binary file (31.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (227 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (350 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx import GraphModule
3
+ from ..export_utils import _WrapperModule
4
+ from ..utils import (
5
+ get_aten_graph_module,
6
+ remove_tensor_overload_for_qdq_ops,
7
+ _replace_literals_with_new_placeholders,
8
+ _replace_literals_with_existing_placeholders,
9
+ )
10
+ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
11
+ from torch.fx.subgraph_rewriter import replace_pattern
12
+ from torch._higher_order_ops.out_dtype import out_dtype
13
+ from typing import Optional, Callable, Tuple, Any
14
+ from dataclasses import dataclass
15
+
16
+ from functools import partial
17
+
18
+ __all__ = [
19
+ "reference_representation_rewrite",
20
+ ]
21
+
22
+
23
# Example inputs used to export the static quantized linear pattern/replacement
# into ATen graphs.  Order matches the function signatures below:
# int8 activation + its qparams/range, int8 weight + its qparams/range,
# fp32 bias, then the output qparams/range.
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _qdq_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Pattern: dequantize input and weight, run fp32 linear, re-quantize output.

    This is traced into a graph and matched structurally, so the exact op
    sequence must stay as-is.
    """
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8
55
+
56
def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Replacement: integer-arithmetic reference implementation of quantized linear.

    Accumulates (x - x_zp) @ (w - w_zp) in int32 via ``out_dtype``, adds the
    bias quantized with scale ``x_scale * weight_scale``, then rescales to the
    output quantization parameters and clamps to the int8 output range.
    """
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    # widen to int16 so the zero-point subtraction below cannot overflow int8
    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation can work for the case
    # no matter if bias_scale == x_scale * weight_scale or not
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None)
    # TODO: change to mul.Scalar
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
87
+
88
+
89
# Example inputs for the dynamically-quantized linear pattern/replacement.
# Order matches the function signatures: fp32 activation + its quant range and
# eps, int8 weight + its qparams/range, then the fp32 bias.
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randn((2, 5), dtype=torch.float),
    -128,
    127,
    torch.finfo(torch.float32).eps,
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
)


def _qdq_dynamic_quantized_linear(
    x_fp32, x_quant_min, x_quant_max, x_eps,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
):
    """Pattern: dynamic quantization of linear — choose activation qparams at
    runtime, quantize then dequantize the activation, dequantize the weight,
    and run an fp32 linear.  Output stays fp32."""
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    return out_fp32
117
+
118
def _reference_dynamic_quantized_linear(
    x_fp32, x_quant_min, x_quant_max, x_eps,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
):
    """Replacement: dynamic quantized linear in integer arithmetic.

    Quantizes the activation with runtime-chosen qparams (quantize_per_tensor
    written out in decomposed form), accumulates in int32 via ``out_dtype``,
    and returns an fp32 result rescaled by ``x_scale * weight_scale``.
    """
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
    # decomposed representation for quantize_per_tensor
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x_fp32 = x_fp32 / x_scale  # fp32
    # round modes might be different here
    # pytorch is rounding to even, which is also common for most of the backends
    x_fp32 = torch.round(x_fp32)  # fp32
    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
    x_i32 = x_i32 + x_zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
    x_i8 = x_i32.to(dtype=torch.int8)

    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    # widen to int16 so the zero-point subtraction below cannot overflow int8
    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation can work for the case
    # no matter if bias_scale == x_scale * weight_scale or not
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None)
    # quantize the fp32 bias with bias_scale = x_scale * weight_scale; see the
    # derivation note in _reference_quantized_conv2d
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    out_fp32 = acc_i32 * (x_scale * weight_scale)
    return out_fp32
153
+
154
+
155
# Example inputs for the static quantized conv2d pattern/replacement
# (NCHW int8 activation + qparams/range, int8 weight + qparams/range,
# fp32 bias, then the output qparams/range).
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _qdq_quantized_conv2d(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Pattern: dequantize input/weight, run fp32 convolution, re-quantize output.

    The conv hyper-parameters below are fixed example literals; the post
    transformation in _REWRITE_INFO_LIST turns literals (except -1) into new
    placeholders before matching.
    """
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.convolution.default(
        x_fp32, weight_fp32, bias_fp32, stride, padding, dilation, transposed, output_padding, groups)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8
194
+
195
def _reference_quantized_conv2d(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    """Replacement: integer-arithmetic reference implementation of quantized conv2d.

    Mirrors _reference_quantized_linear but for aten.convolution; the bias is
    broadcast over the spatial dims via two unsqueeze(-1) calls.
    """
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    # widen to int16 so the zero-point subtraction below cannot overflow int8
    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation can work for the case
    # no matter if bias_scale == x_scale * weight_scale or not
    acc_i32 = out_dtype(
        torch.ops.aten.convolution.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None, stride, padding, dilation, transposed, output_padding, groups)
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to:
    # Take linear calculation for example
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32
    # Represent X, W fp32 as their dequant transforms
    # A_fp32 = (A_q - A_zero_point)/A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_fp32 - X_zp) * X_scale * (W_(i, k)_fp32 - W_zp) * W_scale] + bias_(i)_fp32
    # Factor out X_scale and W_scale
    # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32
    # In order to fold the addition of bias_(i)_fp32 inside, we must do
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950
    # Note we had to multiply bias_fp32 with X_scale * W_scale = bias_scale
    # Thus bias quantization to int32 must be with X_scale * W_scale

    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    # Unsqueeze to match broadcast dims
    # Unfortunately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare
    # in graph pattern replacement
    bias_i32 = bias_i32.unsqueeze(-1)
    bias_i32 = bias_i32.unsqueeze(-1)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(
        torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
250
+
251
+
252
# Example inputs shared by the quantized add and add+relu patterns
# (two int8 operands with their qparams, then output qparams and range).
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _qdq_quantized_add_relu(
    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
    out_scale, out_zero_point, quant_min, quant_max
):
    """Pattern: dequantize both operands, add, relu, then re-quantize the result."""
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
    out_fp32 = x_fp32 + y_fp32
    out_fp32 = torch.ops.aten.relu(out_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8
277
+
278
def _reference_quantized_add_relu(
    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
    out_scale, out_zero_point, quant_min, quant_max
):
    """
    See comments for `_reference_quantized_add` for more information on
    how to derive the formula for out_i8 based on x_i8 and y_i8
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: change this to mul.Scalar?
    x_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (x_i32 - x_zero_point), (x_scale / out_scale))
    y_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (y_i32 - y_zero_point), (y_scale / out_scale))
    out_i32 = x_i32 + y_i32 + out_zero_point
    # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
    # The fused ReLU is realized by the lower clamp bound: out_zero_point is the
    # quantized value of fp32 0, so everything below it is clipped to zero.
    out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
    return out_i8

def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point, quant_min, quant_max):
    """Pattern: dequantize both operands, add, then re-quantize the result."""
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
    out_fp32 = x_fp32 + y_fp32
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8
304
+
305
def _reference_quantized_add(
    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
    out_scale, out_zero_point, quant_min, quant_max
):
    """
    # How to Derive the formula for out_i8 based on x_i8 and y_i8
    # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)

    # out_i8 is quantized output, we can write down the formula for it first:
    out_i8 = out_f32 / out_scale + out_zero_point           (1)

    # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
    out_f32 = x_f32 + y_f32           (2)
    x_fp32 = (x_i8 - x_zero_point) * x_scale         (3)
    y_fp32 = (y_i8 - y_zero_point) * y_scale         (4)

    # applying the above formula to the out_i8 equation we can get the following:
    out_i8 = out_fp32 / out_scale + out_zero_point             # (1)
       = (x_f32 + y_f32) / out_scale + out_zero_point      # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
       = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: use out_dtype op
    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
    out_i32 = x_i32 + y_i32 + out_zero_point
    # NOTE(review): the caller-supplied quant_min/quant_max are overwritten with
    # the hard-coded int8 range here, so the parameters are effectively ignored —
    # presumably to keep the replacement graph's literals stable for matching;
    # confirm before reusing this for non-int8 ranges.
    quant_min = -128
    quant_max = 127
    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
    return out_i8
336
+
337
# Example inputs for the quantized max_pool2d pattern/replacement
# (int8 input + qparams/range, then output qparams/range).
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _qdq_quantized_max_pool2d(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
    """Pattern: dequantize, run fp32 max_pool2d (with indices), re-quantize.

    The pooling hyper-parameters are fixed example literals; the post
    transformation replaces literals with placeholders before matching.
    """
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(x_fp32, kernel_size, stride, padding, dilation, ceil_mode)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8
361
+
362
def _reference_quantized_max_pool2d(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
    """Replacement: max-pool directly on zero-point-shifted integer values,
    then rescale to the output qparams.

    NOTE(review): this relies on pooling commuting with dequantization, which
    holds when x_scale > 0 (dequantize is then monotone) — confirm scales are
    always positive at the call sites.
    """
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
    x_i32 = x_i8.to(torch.int32)
    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_i32 - x_zero_point,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode
    )
    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
    out_i8 = out_fp32.to(torch.int8)
    return out_i8
384
+
385
# Example inputs for quantize_per_tensor: fp32 input, scale, zero_point,
# quant_min, quant_max.
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
    """Pattern: the decomposed quantize_per_tensor op itself."""
    x = torch.ops.quantized_decomposed.quantize_per_tensor(x_fp32, scale, zero_point, quant_min, quant_max, torch.int8)
    return x

def _reference_quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
    """Replacement: quantize_per_tensor written out as primitive aten ops
    (divide by scale, round, add zero point, clamp, cast to int8)."""
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x = x_fp32 / scale  # fp32
    # round modes might be different here
    # pytorch is rounding to even, which is also common for most of the backends
    x = torch.round(x)  # fp32
    x = x.to(dtype=torch.int32)  # int32
    x = x + zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x = torch.clamp(x, quant_min, quant_max)  # int32
    x = x.to(dtype=torch.int8)
    return x
409
+
410
# Example inputs for dequantize_per_tensor: int8 input, scale, zero_point,
# quant_min, quant_max.
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
    """Pattern: the decomposed dequantize_per_tensor op itself."""
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max, torch.int8)
    return x_fp32

def _reference_dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
    """Replacement: dequantize_per_tensor written out as primitive aten ops."""
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    # TODO: use out_dtype op
    # note: x_i8.to(torch.int32) does not work here
    # TODO: debug the implementation later when torchdynamo time out issue is resolved
    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
431
+
432
# Example inputs for per-channel quantize: fp32 input, per-channel scales and
# zero_points, channel axis, quant_min, quant_max.
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)

def _quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
    """Pattern: the decomposed quantize_per_channel op itself."""
    out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
        x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_i8

def _reference_quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
    """Replacement: per-channel quantize — transpose the channel axis to the
    last dim so the 1-D scales/zero_points broadcast, quantize, transpose back."""
    x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
    out_i32 = torch.ops.aten.clamp(torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max)
    out_i32 = torch.transpose(out_i32, ch_axis, -1)
    return out_i32.to(torch.int8)
452
+
453
# Example inputs for per-channel dequantize: int8 input, per-channel scales and
# zero_points, channel axis, quant_min, quant_max.
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)

def _dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
    """Pattern: the decomposed dequantize_per_channel op itself."""
    # the following will be replaced as placeholders
    out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
        x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_fp32

def _reference_dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
    """Replacement: per-channel dequantize — transpose the channel axis to the
    last dim so the 1-D scales/zero_points broadcast, dequantize, transpose back."""
    # the following will be replaced as placeholders
    # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops)
    # we call a torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    x_i8 = torch.transpose(x_i8, ch_axis, -1)
    x_i32 = x_i8.to(torch.int32)
    out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
    out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
    return out_fp32
479
+
480
def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
    """Post-transform for the per-channel q/dq graphs: replace the hard-coded
    example literals with the existing placeholders at the matching argument
    positions (ch_axis literal 1 -> placeholder 3, quant_min -128 -> 4,
    quant_max 127 -> 5, per the example-input order).  -1 is excluded because
    it is a genuine literal: the transpose target axis."""
    return _replace_literals_with_existing_placeholders(
        gm,
        exclude_literals=[-1],
        literal_to_ph_idx={1: 3, -128: 4, 127: 5}
    )
486
+
487
+
488
@dataclass
class _RewriteInfo:
    """Data needed for rewrite, this includes example inputs, pattern and replacement functions
    and post transformation functions for the exported pattern and replacement GraphModule
    """

    # example inputs used for exporting the pattern into GraphModule
    example_inputs: Tuple[Any, ...]
    # callable exported/traced to produce the graph to search for
    pattern: Callable
    # callable exported/traced to produce the graph substituted on a match
    replacement: Callable
    # post transformation on the exported pattern and replacement GraphModule
    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
501
+
502
# Ordered list of rewrites applied by reference_representation_rewrite.
# Each entry pairs a q/dq "pattern" with its integer-arithmetic "replacement";
# the optional post-transformation callables turn hard-coded example literals
# into placeholders so matching is not tied to the example values.
_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_dynamic_quantized_linear),
        _WrapperModule(_reference_dynamic_quantized_linear),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={
                -128: 1,
                127: 2,
                torch.finfo(torch.float32).eps: 3
            }
        ),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={
                -128: 1,
                127: 2,
                torch.finfo(torch.float32).eps: 3
            }
        ),
    ),
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_linear),
        _WrapperModule(_reference_quantized_linear),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_conv2d),
        _WrapperModule(_reference_quantized_conv2d),
        # -1 is excluded: it is a real literal (bias unsqueeze axis)
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add_relu),
        _WrapperModule(_reference_quantized_add_relu),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add),
        _WrapperModule(_reference_quantized_add),
    ),
    _RewriteInfo(
        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_max_pool2d),
        _WrapperModule(_reference_quantized_max_pool2d),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders
    ),
    _RewriteInfo(
        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_tensor_int8),
        _WrapperModule(_reference_quantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_tensor_int8),
        _WrapperModule(_reference_dequantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_channel_int8),
        _WrapperModule(_reference_quantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_channel_int8),
        _WrapperModule(_reference_dequantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement
    ),
]
580
+
581
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
    """Rewrite quantize/dequantize patterns in ``model`` into their integer
    reference-arithmetic representation.

    For every entry in ``_REWRITE_INFO_LIST``, the pattern and replacement
    callables are exported into ATen graphs from the example inputs,
    normalized (tensor-overload q/dq ops removed, literals optionally turned
    into placeholders by the post-transforms), and substituted into ``model``
    via ``torch.fx.subgraph_rewriter.replace_pattern``.

    Args:
        model: exported GraphModule containing decomposed q/dq ops.

    Returns:
        The same ``model`` instance, mutated in place.
    """
    remove_tensor_overload_for_qdq_ops(model)
    for rewrite_info in _REWRITE_INFO_LIST:
        example_inputs = rewrite_info.example_inputs
        pattern = rewrite_info.pattern
        replacement = rewrite_info.replacement
        pattern_post_trans = rewrite_info.pattern_post_trans
        replacement_post_trans = rewrite_info.replacement_post_trans
        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
        if pattern_post_trans:
            pattern = pattern_post_trans(pattern)
        if replacement_post_trans:
            replacement = replacement_post_trans(replacement)
        pattern.recompile()  # type: ignore[attr-defined]
        replacement.recompile()  # type: ignore[attr-defined]
        # replace_pattern returns match records; nothing downstream uses them,
        # so drop the result instead of binding an unused local.
        replace_pattern(model, pattern, replacement)
    return model
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (214 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc ADDED
Binary file (9.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Whether to disable showing progress on compilation passes.
# A separate config is needed here, otherwise we will get a circular import if
# the dynamo config were imported instead.
disable_progress = True

# If True, also show the node names in each pass; great for small models, but
# quite noisy for larger ones.
verbose_progress = False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc ADDED
Binary file (1.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc ADDED
Binary file (62.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc ADDED
Binary file (39.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.utils._pytree as pytree
8
+
9
+
10
+ __all__ = [
11
+ "ShapeEnvEvent",
12
+ "record_shapeenv_event",
13
+ "replay_shape_env_events",
14
+ "FakeTensorMeta",
15
+ "shape_env_check_state_equal",
16
+ "NotEqualError",
17
+ ]
18
+
19
+ # [Note: Recording ShapeEnv Events]
20
+ # =================================
21
+ #
22
+ # What is a ShapeEnv event?
23
+ # -------------------------
24
+ # We consider a ShapeEnv event every function call (ShapeEnv method or
25
+ # independent function) that modifies the state of the ShapeEnv instance.
26
+ # Such calls are recorded alongside their positional and keyword arguments,
27
+ # so that it may be replayed over a different ShapeEnv instance.
28
+ #
29
+ # See [Note: ShapeEnv State Equality] for what is considered the state
30
+ # of a ShapeEnv instance.
31
+ #
32
+ # What is it for?
33
+ # ---------------
34
+ # ShapeEnv events recording is used for reconstructing the ShapeEnv in an
35
+ # arbitrary state in time.
36
+ #
37
+ # Being able to arbitrarily replay events like so is useful, mainly for
38
+ # translation validation bisection. i.e. if a ValidationException has been
39
+ # raised, find the earliest point in time where the translation validation
40
+ # fails.
41
+ #
42
+ # Besides that, it also allows us to inspect the given instance and,
43
+ # for example, check the guards that would actually be issued at that point.
44
+ #
45
+ # What kind of arguments can be stored in an event?
46
+ # -------------------------------------------------
47
+ # There's no specific rule for what cannot be used as an argument.
48
+ # That said, pay special attention to the following cases:
49
+ #
50
+ # 1. Tensor inputs: there are some tests that check whether the inputs
51
+ # were garbage collected after execution. These will fail if there's
52
+ # an event that is holding a reference to those inputs.
53
+ #
54
+ # 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
55
+ # will be automatically replaced by the new given ShapeEnv instance.
56
+ #
57
+ # 3. SymTypes arguments: they also hold references to ShapeEnv. So,
58
+ # whenever we see them, we create a new instance, replacing the
59
+ # ShapeEnv reference.
60
+ #
61
+ # 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
62
+ # shapes. That argument must be replaced when replaying the event at
63
+ # ShapeEnvEvent.run, since it has to reference a node from the given
64
+ # instance, and not from the recorded instance.
65
+
66
+
67
+ # Event class for reconstructing ShapeEnv at arbitrary time.
68
+ #
69
+ # Represents a method call that mutates ShapeEnv in a way that affects the
70
+ # issued guards, when ShapeEnv.produce_guards is called.
71
@dataclass
class ShapeEnvEvent:
    """A single recorded, replayable mutation of a ShapeEnv.

    Represents a method call that mutates ShapeEnv in a way that affects the
    issued guards, when ShapeEnv.produce_guards is called. Instances are
    appended to ShapeEnv.events by the record_shapeenv_event decorator and
    replayed over a fresh ShapeEnv via run().
    """

    # ShapeEnv method (or the ShapeEnv class itself for the constructor event).
    f: Callable

    # Arguments and keyword arguments called with.
    args: Optional[List[Any]] = None
    kwargs: Optional[Dict[str, Any]] = None

    # List of tracked_fakes at the time the method was called.
    tracked_fakes: Optional[List[Any]] = None

    # Name of the captured event.
    # Used for special handling of particular methods.
    name: Optional[str] = None

    # Replay itself, but using shape_env as self.
    def run(self, shape_env=None) -> Any:
        """Replay this event against *shape_env*.

        For the constructor event, shape_env must be None and a new ShapeEnv
        is returned. Otherwise, every ShapeEnv / SymTypes / FX-node argument
        recorded from the old instance is rewritten to reference *shape_env*
        before the recorded function is called.
        """
        from torch.fx.experimental.symbolic_shapes import (
            is_symbolic,
            ShapeEnv,
            SymTypes,
        )

        # Special handling for the constructor event.
        if self.f is ShapeEnv:
            assert shape_env is None and self.args is None and self.kwargs is not None
            return ShapeEnv(**self.kwargs)

        assert shape_env is not None
        args = list(self.args or list())
        kwargs = dict(self.kwargs or dict())

        # Replace any argument of type ShapeEnv by the given one.
        args, kwargs = pytree.tree_map_only(
            ShapeEnv, lambda _: shape_env, (args, kwargs)
        )

        # Replace any argument of type SymTypes by a new instance,
        # replacing its ShapeEnv reference.
        args, kwargs = pytree.tree_map_only(
            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
            lambda a: type(a)(a.node.with_shape_env(shape_env)),
            (args, kwargs),
        )

        # Converts FX nodes using the mapping argument.
        def maybe_convert_node(x: Any) -> Any:
            if not isinstance(x, torch.fx.Node):
                # Don't do anything to x if it's not an FX node.
                return x

            # If, at some point, we created an FX node, it means that translation validation is on.
            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
            # we are tracking node names at shape_env.name_to_node.
            assert hasattr(shape_env, "name_to_node")
            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
            assert x.name in name_to_node
            return name_to_node[x.name]

        # Replaces the value of a specific argument by the result of fn,
        # whether it was passed positionally (index) or by keyword (key).
        def replacearg(index: int, key: str, fn: Callable):
            if index < len(args):
                args[index] = fn(args[index])
            if key in kwargs:
                kwargs[key] = fn(kwargs[key])

        if self.is_create_fx_call_function():
            # ShapeEnv.create_fx_call_function:
            # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
            # They must be replaced, since a "call_function" FX node with this tuple as argument
            # will be added to the FX graph of the new shape_env.
            replacearg(
                index=2,
                key="args",
                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
            )
        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
            # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
            # "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
            # They must be replaced, since it will be part of a "call_function" FX node for
            # torch._assert, which will be added to the FX graph of the new shape_env.
            replacearg(index=3, key="fx_node", fn=maybe_convert_node)

        # Actually call the method with the converted arguments.
        return self.f(*args, **kwargs)

    def __str__(self) -> str:
        name = self.name if self.name is not None else self.f.__name__
        return f"event: {name} ({self.args}, {self.kwargs})"

    # Predicates used by run() to decide which arguments need FX-node
    # conversion for the corresponding recorded method.
    def is_create_fx_call_function(self) -> bool:
        return self.name == "_create_fx_call_function"

    def is_evaluate_expr(self) -> bool:
        return self.name == "evaluate_expr"

    def is_defer_runtime_assert(self) -> bool:
        return self.name == "defer_runtime_assert"
170
+
171
+
172
+ # Extracts a ShapeEnv instance inside args and kwargs.
173
+ # Specifically, it looks for:
174
+ # 1. ShapeEnv arguments
175
+ # 2. SymInt, SymFloat, or SymBool arguments
176
+ # If we find more than one object of any of the above types, we
177
+ # also check that the ShapeEnv instance is the same for all of them.
178
+ def _extract_shape_env_and_assert_equal(args, kwargs):
179
+ from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
180
+
181
+ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
182
+ if old is not None:
183
+ assert old is new, "call with different ShapeEnv"
184
+ return new
185
+
186
+ shape_env = None
187
+ for val in itertools.chain(args, kwargs.values()):
188
+ if isinstance(val, ShapeEnv):
189
+ shape_env = assert_equal(shape_env, val)
190
+ if isinstance(val, SymTypes) and is_symbolic(val):
191
+ shape_env = assert_equal(shape_env, val.node.shape_env)
192
+
193
+ return shape_env
194
+
195
+
196
+ # Decorator for recording the given function as a replayable event.
197
+ #
198
+ # This decorator should be used at every function that mutates the state of
199
+ # ShapeEnv in some way that affects the resulting issued guards (i.e. when
200
+ # ShapeEnv.produce_guards is called).
201
+ #
202
+ # save_tracked_fakes: saves a snapshot of the TrackedFake list.
203
+ # This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
204
+ #
205
+ # When to save the list of TrackedFake?
206
+ # =====================================
207
+ # We should save the list of TrackedFake whenever the translation validation
208
+ # bisection may actually stop and call the produce_guards method at the moment
209
+ # right after the recorded function was played. In other words, since the
210
+ # bisection bisects through torch._assert calls, we should save in all methods
211
+ # that adds a torch._assert call to the symbolic shapes FX graph.
212
+ #
213
+ # At the moment, there are 2 methods that save the list:
214
+ # - ShapeEnv.evaluate_expr
215
+ # - ShapeEnv.defer_runtime_assert
216
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
    """Decorator that records calls to *fn* as replayable ShapeEnvEvents.

    Apply to every function that mutates ShapeEnv state in a way that affects
    the guards produced by ShapeEnv.produce_guards. When save_tracked_fakes
    is True, a snapshot of the TrackedFake list is stored with the event.
    """

    def decorator(fn: Callable) -> Callable:
        assert callable(fn)
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            # Fast path: a recording is already in progress, so this call is
            # nested inside another recorded event. Dispatch directly without
            # the (slower) cross-argument ShapeEnv equality check.
            if isinstance(args[0], ShapeEnv) and args[0].is_recording:  # type: ignore[has-type]
                return fn(*args, **kwargs)

            # Locate the ShapeEnv instance referenced by the arguments.
            # Assumption: args and kwargs never mix different instances.
            env = _extract_shape_env_and_assert_equal(args, kwargs)

            # No live ShapeEnv among the arguments: nothing to record.
            if env is None:
                return fn(*args, **kwargs)

            with env._recording():
                # Optionally snapshot the current tracked_fakes.
                snapshot = env._snapshot_tracked_fakes() if save_tracked_fakes else None
                # Record the event for 'fn', then immediately play it on this
                # ShapeEnv so recording and execution stay in lock-step.
                event = ShapeEnvEvent(fn, list(args), kwargs, snapshot, name=name)
                env.events.append(event)
                return event.run(env)

        return wrapper

    return decorator
260
+
261
+
262
+ # Replays the ShapeEnvEvents list.
263
+ # It assumes the first event is the constructor call.
264
+ #
265
+ # fn: transforms an old FX node into one corresponding to the newly created ShapeEnv.
266
def replay_shape_env_events(events):
    """Replay a list of ShapeEnvEvents onto a freshly built ShapeEnv.

    The first event must be the ShapeEnv constructor event; every subsequent
    event is replayed, in order, against the new instance, which is returned.
    Raises RuntimeError (chained to the original exception) if any event fails.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    constructor, *rest = events
    assert constructor.f == ShapeEnv

    # Build the new ShapeEnv from the recorded constructor arguments.
    shape_env = constructor.run()

    for event in rest:
        try:
            # Each event rewrites its own FX-node references on replay, since
            # the node list may change after every replayed event.
            event.run(shape_env)
        except Exception as e:
            raise RuntimeError(f"failed when running event: {event}") from e

    return shape_env
285
+
286
+
287
+ # FakeTensor metadata.
288
+ # This is to be used in place of FakeTensor placeholders when calling
289
+ # ShapeEnv.produce_guards.
290
@dataclass
class FakeTensorMeta:
    """FakeTensor metadata snapshot.

    This is to be used in place of FakeTensor placeholders when calling
    ShapeEnv.produce_guards: it mirrors the size()/stride()/storage_offset()/
    dim() accessors of a tensor without holding a reference to the tensor
    itself (see [Note: Recording ShapeEnv Events] on why events must not keep
    tensor inputs alive).
    """

    # Sizes, strides and storage offset may be symbolic (SymInt) or concrete.
    tensor_size: Tuple[Union[int, torch.SymInt], ...]
    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
        return self.tensor_size

    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        return self.tensor_storage_offset

    def dim(self) -> int:
        return len(self.tensor_size)

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        """Build a FakeTensorMeta from any object exposing the tensor
        metadata accessors (a FakeTensor, or another FakeTensorMeta)."""
        return FakeTensorMeta(
            fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
        )
314
+
315
+
316
+ # [Note: ShapeEnv State Equality]
317
+ # ===============================
318
+ #
319
+ # What is considered ShapeEnv state?
320
+ # ----------------------------------
321
+ # We consider to be the state of a ShapeEnv instance everything that
322
+ # is not in the inline tuple inside remove_nonstate_variables function.
323
+ # That is: the fields within ShapeEnv that modify the flow of execution
324
+ # of the program.
325
+ #
326
+ # So, for example: the replacements field might influence on how an
327
+ # expression is simplified. That, in turn, may result in a guard being
328
+ # statically known (i.e. not added).
329
+ #
330
+ # On the other hand, var_to_stack serves only changes what is printed
331
+ # in the screen, i.e. used only for debugging purposes. Therefore, we
332
+ # should not consider it when comparing states.
333
+ #
334
+ # What to do on NotEqualError?
335
+ # ----------------------------
336
+ # Here are a few possible causes for getting a NotEqualError raised:
337
+ #
338
+ # 1. New field that does not belong in the ShapeEnv state.
339
+ # For example: log field of type ShapeEnvLoggerAdapter. Different
340
+ # ShapeEnv instances will always have different ShapeEnvLoggerAdapter
341
+ # instances, i.e. equality comparison would fail.
342
+ # Solution: add it to the inlined tuple inside remove_nonstate_variables
343
+ # function inside check_equal method.
344
+ #
345
+ # 2. New field that is not directly comparable across instances.
346
+ # For example: guards field of type List[ShapeGuard]. More specifically,
347
+ # the ShapeGuard type holds an expression and a stack information
348
+ # for debugging purposes. When replaying the event on a new ShapeEnv
349
+ # instance, the stack would be different, which would trigger this error.
350
+ # Solution: add a special case to the map_value function inside
351
+ # check_equal function.
352
+ #
353
+ # 3. Mutation of ShapeEnv on some not recorded function.
354
+ # If a mutation of the state of ShapeEnv happens inside a function
355
+ # that is not recorded (or that no caller in the stack is recorded),
356
+ # then, the replayed ShapeEnv won't catch that.
357
+ # Solution: decorate the function with record_shape_env_event.
358
+
359
+
360
+ # Checks whether the state of two ShapeEnv are equal w.r.t. the guards
361
+ # returned by ShapeEnv.produce_guards.
362
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    """Check that two ShapeEnv instances have equal state w.r.t. the guards
    returned by ShapeEnv.produce_guards.

    Fields named in *non_state_variable_names* are ignored; every remaining
    field of both instances is passed through *map_value* before comparison.
    Raises NotEqualError on any mismatch; returns None when equal.
    """
    # Build filtered copies of each instance's __dict__, dropping variables
    # that don't represent ShapeEnv state. The instances themselves are
    # never mutated.
    skip = set(non_state_variable_names)
    state1 = {k: v for k, v in vars(env1).items() if k not in skip}
    state2 = {k: v for k, v in vars(env2).items() if k not in skip}

    keys1, keys2 = set(state1), set(state2)

    # The two instances must expose the same set of state fields.
    if keys1 != keys2:
        raise NotEqualError(
            "field set mismatch:",
            [
                (
                    "found unique fields:",
                    str(sorted(keys1 - keys2)),
                    str(sorted(keys2 - keys1)),
                ),
            ],
        )

    # Render a mismatched value as a stable string: dict and set entries are
    # sorted, since their iteration order may differ between instances.
    def stringify(value: Any) -> str:
        if isinstance(value, dict):
            body = ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
            return "{" + body + "}"
        if isinstance(value, set):
            return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
        return str(value)

    # Compare the mapped value of every field, in sorted key order, and
    # accumulate any mismatches.
    mismatched: List[Tuple[str, str, str]] = []
    for key in sorted(keys1):
        lhs = map_value(key, state1[key])
        rhs = map_value(key, state2[key])
        if lhs != rhs:
            mismatched.append(
                (f"{key}: values don't match.", stringify(lhs), stringify(rhs))
            )

    if mismatched:
        raise NotEqualError("field values don't match:", mismatched)
431
+
432
+
433
class NotEqualError(Exception):
    """Raised by shape_env_check_state_equal when two ShapeEnv states differ.

    *mismatched* is a list of (message, left-value, right-value) triples, one
    per differing field, which are rendered into the exception message.
    """

    def __init__(
        self,
        msg: str,
        mismatched: List[Tuple[str, str, str]],
    ) -> None:
        # Render each mismatch as a three-line entry.
        lines: List[str] = []
        for inner_msg, str1, str2 in mismatched:
            lines.append(f"==> {inner_msg}")
            lines.append(f"  > Left: {str1}")
            lines.append(f"  > Right: {str2}")
        details = "\n".join(lines)

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}

{details}
"""
        )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Equality:
    """An equality constraint ``lhs = rhs`` between two type expressions.

    Value object: two Equality instances compare equal iff both sides are
    equal. Assumes lhs/rhs are hashable when instances are used as dict keys
    or set members.
    """

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def __str__(self):
        return f'{self.lhs} = {self.rhs}'

    def __repr__(self):
        # Same rendering as __str__; kept as a distinct method so repr()
        # of containers holding Equality stays readable.
        return str(self)

    def __eq__(self, other):
        if isinstance(other, Equality):
            return self.lhs == other.lhs and self.rhs == other.rhs
        else:
            return False

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, making instances
        # unhashable; restore a hash consistent with __eq__ so Equality
        # can be stored in sets and used as a dict key.
        return hash((self.lhs, self.rhs))