koichi12 commited on
Commit
1c399ca
·
verified ·
1 Parent(s): e7b25d3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py +28 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py +50 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py +258 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py +416 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py +134 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py +26 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py +21 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py +6 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py +11 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py +557 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +1279 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +1040 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +348 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py +125 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py +44 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py +112 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py +73 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py +421 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py +329 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py +217 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py +66 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py +195 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py +302 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py +95 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py +400 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py +87 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch._export.db.examples as examples
5
+
6
+ TEMPLATE = '''import torch
7
+
8
+ from torch._export.db.case import export_case
9
+
10
+
11
+ @export_case(
12
+ example_inputs=(torch.randn(3, 2),),
13
+ tags={{}},
14
+ )
15
+ def {case_name}(x):
16
+ """
17
+ """
18
+
19
+ return
20
+ '''
21
+
22
+ if __name__ == "__main__":
23
+ assert len(sys.argv) == 2
24
+ root_dir = examples.__name__.replace(".", "/")
25
+ assert os.path.exists(root_dir)
26
+ with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f:
27
+ print("Writing to", f.name, "...")
28
+ f.write(TEMPLATE.format(case_name=sys.argv[1]))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+
4
+ import torch
5
+ import torch.fx
6
+
7
+
8
+ # TODO(ycao): This is added to avoid breaking existing code temporarily.
9
+ # Remove when migration is done.
10
+ from torch.export.graph_signature import (
11
+ ExportBackwardSignature,
12
+ ExportGraphSignature,
13
+ )
14
+
15
+ from torch.export.exported_program import (
16
+ ExportedProgram,
17
+ ModuleCallEntry,
18
+ ModuleCallSignature,
19
+ )
20
+
21
+
22
+
23
+ __all__ = [
24
+ "ExportBackwardSignature",
25
+ "ExportGraphSignature",
26
+ "ExportedProgram",
27
+ "ModuleCallEntry",
28
+ "ModuleCallSignature",
29
+ ]
30
+
31
+
32
+ def _create_graph_module_for_export(root, graph):
33
+ try:
34
+ gm = torch.fx.GraphModule(root, graph)
35
+ except SyntaxError:
36
+ # If custom objects stored in memory are being used in the graph,
37
+ # the generated python code will result in a syntax error on the custom
38
+ # object, since it is unable to parse the in-memory object. However
39
+ # we can still run the graph eagerly through torch.fx.Interpreter,
40
+ # so we will bypass this error.
41
+ warnings.warn(
42
+ "Unable to execute the generated python source code from "
43
+ "the graph. The graph module will no longer be directly callable, "
44
+ "but you can still run the ExportedProgram, and if needed, you can "
45
+ "run the graph module eagerly using torch.fx.Interpreter."
46
+ )
47
+ gm = torch.fx.GraphModule(root, torch.fx.Graph())
48
+ gm._graph = graph
49
+
50
+ return gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from collections import defaultdict
3
+ from typing import Any, Callable, Dict, List, Tuple, Union
4
+
5
+ import torch
6
+ from torch._dynamo.source import (
7
+ AttrSource,
8
+ GetItemSource,
9
+ LocalSource,
10
+ TensorProperty,
11
+ TensorPropertySource,
12
+ )
13
+ from torch._dynamo.variables.builder import TrackedFake
14
+ from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
15
+ from torch._guards import Source
16
+ from torch._subclasses.fake_tensor import FakeTensorMode
17
+ from torch.export import Constraint
18
+ from torch.export.graph_signature import CustomObjArgument
19
+ from torch.fx.experimental.symbolic_shapes import (
20
+ ConstraintViolationError,
21
+ DimDynamic,
22
+ EqualityConstraint,
23
+ ShapeEnv,
24
+ StatelessSymbolicContext,
25
+ )
26
+ from torch.utils._pytree import (
27
+ GetAttrKey,
28
+ KeyPath,
29
+ MappingKey,
30
+ SequenceKey,
31
+ tree_map_with_path,
32
+ )
33
+
34
+
35
+ def key_path_to_source(kp: KeyPath) -> Source:
36
+ """
37
+ Given a key path, return the source for the key path.
38
+ """
39
+ source: Source = LocalSource("args")
40
+ for k in kp:
41
+ if isinstance(k, SequenceKey):
42
+ source = GetItemSource(source, k.idx)
43
+ elif isinstance(k, MappingKey):
44
+ source = GetItemSource(source, k.key)
45
+ elif isinstance(k, GetAttrKey):
46
+ source = AttrSource(source, k.name)
47
+ else:
48
+ raise ValueError(f"Unknown KeyEntry {k}")
49
+
50
+ return source
51
+
52
+
53
+ def _is_constant_argument(t):
54
+ return t is None or isinstance(t, (int, float, bool, str))
55
+
56
+
57
def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: Dict[int, Dict[int, Constraint]],
    sources: Dict[Tuple[int, int], List[Source]],
):
    """Convert one pytree leaf ``t`` into a fake tensor under ``mode``.

    Constants and ``ScriptObject``s pass through unchanged.  Tensors are
    faked with per-dimension dynamism taken from ``t_constraints`` (keyed
    by ``id(t)``, then dimension index); ``sources`` is updated in place
    with the size source of every constrained dimension.

    Raises:
        ValueError: if ``t`` is neither a constant, a ScriptObject, nor a Tensor.
    """
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
        return t
    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    # Default: every dimension is static unless a constraint marks it dynamic below.
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.STATIC] * n_dims,
        constraint_sizes=[None] * n_dims,
    )
    t_id = id(t)
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            # Map the internal source name to the user-facing Dim name so
            # later constraint errors are readable.
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    # Record the fake so make_constraints() can feed it to produce_guards().
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))
    return fake
85
+
86
+
87
def make_fake_params_buffers(
    fake_mode: FakeTensorMode,
    params_buffers: Dict[str, torch.Tensor],
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
    """Re-create every parameter/buffer as a static-shaped fake tensor under ``fake_mode``."""
    return {
        name: fake_mode.from_tensor(tensor, static_shapes=True)
        for name, tensor in params_buffers.items()
    }
95
+
96
+
97
def make_fake_inputs(nn_module, args, kwargs, constraints):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    # - Fakify inputs.
    # - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    # - output_graph.py fakifies inputs.
    # - [post-tracing] guards.py processes input shape equalities.

    # Index constraints by tensor id then dimension so fakify() can look them
    # up in O(1) per leaf.
    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint
        if constraint.shared is not None:
            # A shared constraint also applies to the tensor/dim it is tied to.
            t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint

    # Capture where the user's forward() is defined so guard/constraint
    # messages can point at it.
    code = nn_module.forward.__code__
    co_fields = {
        "co_name": code.co_name,
        "co_filename": code.co_filename,
        "co_firstlineno": code.co_firstlineno,
    }

    fake_mode = FakeTensorMode(
        shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
        allow_non_fake_inputs=True,
    )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        original_signature = inspect.signature(nn_module.forward)
        # (tensor id, dim) -> sources for that dimension; filled by fakify().
        sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

    from sympy import Symbol

    # Translate user-level constraints into source-level equalities for the
    # shape env to enforce during guard production.
    source_pairs: List[Tuple[Source, Source]] = []
    derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
    phantom_symbols: Dict[str, Symbol] = {}
    for constraint in constraints:
        torch.export.dynamic_shapes._process_equalities(
            constraint,
            lambda t_id, dim: sources[(t_id, dim)],
            fake_mode.shape_env,
            source_pairs,
            derived_equalities,
            phantom_symbols,
        )

    equalities_inputs = EqualityConstraint(
        source_pairs=source_pairs,
        derived_equalities=derived_equalities,
        phantom_symbols=list(phantom_symbols.values()),
        warn_only=False,
    )
    return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature
167
+
168
+
169
def make_constraints(
    fake_mode,
    equalities_inputs,
    original_signature,
    gm,
):
    """
    Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions,
    and a graph module, produce guards on the fake mode's shape env (raising constraint
    violations if any), solve (to suggest simplifications or fixes), and return the
    resulting range constraints and equality constraints.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following post-tracing steps:
    # - Produce guards.
    # - Solve constraints.
    # - Install shape metadata in IR.
    # In strict, these steps are spread across multiple files:
    # - guards.py produces guards.
    # - eval_frame.py solves constraints
    # - _trace.py installs shape metadata in IR.

    shape_env = fake_mode.shape_env
    # tracked_fakes was populated by fakify() during input fakification.
    placeholders = [tf.fake for tf in shape_env.tracked_fakes]
    sources = [tf.source for tf in shape_env.tracked_fakes]
    input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
    # Defer raising so the solver's message can be appended to the error first.
    constraint_violation_error = None
    try:
        shape_env.produce_guards(
            placeholders,
            sources,
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            ignore_static=False,
        )
    except ConstraintViolationError as e:
        constraint_violation_error = e

    # Freeze so no new guards can be added after this point.
    shape_env.frozen = True
    dim_constraints = shape_env.dim_constraints
    if dim_constraints is None:
        # Expected when shape_env.produce_guards throws an early constraint violation error.
        # There is nothing to solve for in this case.
        # TODO(avik): Maybe record the constraint violation error instead and replay later?
        assert constraint_violation_error
        raise constraint_violation_error
    dim_constraints.solve()
    dim_constraints.remove_redundant_dynamic_results()
    forced_specializations = dim_constraints.forced_specializations()
    msg = dim_constraints.prettify_results(
        original_signature, constraint_violation_error, forced_specializations
    )
    if constraint_violation_error:
        # Append the solver's suggestions to the original error message.
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
    elif forced_specializations:
        # No guard-time violation, but the solver forced some dims static:
        # surface that as a violation too.
        constraint_violation_error = ConstraintViolationError(msg)
    if constraint_violation_error:
        raise constraint_violation_error

    # Collect a range constraint for every symbolic placeholder dimension.
    range_constraints = {}
    input_dims = defaultdict(list)
    free_symbols = set()
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        if _is_constant_argument(node.meta["val"]) or isinstance(
            node.meta["val"], CustomObjArgument
        ):
            # Constants and custom objects carry no shape information.
            continue
        for i, d in enumerate(node.meta["val"].shape):
            if isinstance(d, torch.SymInt):
                # Look up the range constraint for the symbol corresponding to this shape dimension
                # and store it indexed by the symbolic expression corresponding to it.
                # NOTE(avik): Use node._expr instead of node.expr for the lookup here because
                # we want the symbol, not its replacement, which could be an expression. Maybe
                # there's a better way to do this, e.g., by (re)computing value ranges for expressions?
                range_constraints[d.node.expr] = shape_env.var_to_range[d.node._expr]
                input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
                free_symbols.update(d.node.expr.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = shape_env.var_to_range[symbol]

    return range_constraints
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ import operator
4
+ from collections.abc import Iterable
5
+ from typing import Any, Dict, final, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+ from torch._ops import HigherOrderOperator, OpOverload
9
+ from torch._subclasses.fake_tensor import FakeTensor
10
+ from torch.export.exported_program import ExportedProgram
11
+ from torch.export.graph_signature import (
12
+ CustomObjArgument,
13
+ InputKind,
14
+ SymIntArgument,
15
+ TensorArgument,
16
+ )
17
+ from torch.fx import GraphModule
18
+ from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
19
+
20
+
21
class SpecViolationError(Exception):
    """Raised when an exported program violates its dialect's specification."""
23
+
24
+
25
def is_functional(op: OpOverload) -> bool:
    """An op is functional iff its schema declares no mutable arguments."""
    schema = op._schema
    return not schema.is_mutable
27
+
28
+
29
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    """Legacy alias kept for callers that predate ``_check_val``."""
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)
32
+
33
+
34
+ def _check_val(node: torch.fx.Node) -> None:
35
+ def _check_correct_val(val):
36
+ if val is None:
37
+ return True
38
+ elif isinstance(val, (int, bool, str, float)):
39
+ return True
40
+ elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
41
+ return True
42
+ elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor.
43
+ return True
44
+ elif isinstance(val, (SymInt, SymFloat, SymBool)):
45
+ return True
46
+ elif isinstance(val, CustomObjArgument):
47
+ return True
48
+ elif isinstance(val, Iterable):
49
+ return all(_check_correct_val(x) for x in val)
50
+ return False
51
+
52
+ def _no_returns(op):
53
+ if not isinstance(op, OpOverload):
54
+ return False
55
+ return len(op._schema.returns) == 0
56
+
57
+ if "val" not in node.meta:
58
+ if node.op == "call_function" and _no_returns(node.target):
59
+ return
60
+ raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
61
+
62
+ val = node.meta["val"]
63
+ if not _check_correct_val(val):
64
+ raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
65
+
66
+
67
class _VerifierMeta(type):
    """Metaclass that registers every Verifier class by its ``dialect`` string
    and enforces the base-vs-subclass contract at class-creation time."""

    # Global dialect-name -> Verifier-class registry, populated as a side
    # effect of defining each (sub)class.
    _registry: Dict[str, Type['Verifier']] = {}

    def __new__(metacls, name, bases, attrs):
        if bases:
            # Subclasses: the final check entry points may not be overridden,
            # and each subclass must declare its own non-"ATEN" dialect.
            if "check" in attrs or "_check_graph_module" in attrs:
                raise SyntaxError("Overriding method check is not allowed.")
            assert "dialect" in attrs and attrs["dialect"] != "ATEN"
        else:
            # The base Verifier class itself: must define the check methods
            # and claim the "ATEN" dialect.
            assert "check" in attrs
            assert "_check_graph_module" in attrs
            assert attrs["dialect"] == "ATEN"

        assert isinstance(attrs["dialect"], str)
        ret = type.__new__(metacls, name, bases, attrs)
        metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
        return ret
84
+
85
def getattr_recursive(obj: Any, target: str) -> Any:
    """Resolve a dotted attribute path (e.g. ``"a.b.c"``) against ``obj``.

    Raises:
        RuntimeError: naming the longest prefix that did resolve, when some
            component along the path is missing.
    """
    atoms = target.split('.')
    current = obj
    for depth, atom in enumerate(atoms):
        if not hasattr(current, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(atoms[:depth])}")
        current = getattr(current, atom)
    return current
93
+
94
+
95
class Verifier(metaclass=_VerifierMeta):
    """Base verifier for the core "ATEN" dialect.

    Subclasses customize the allow-lists (``allowed_*``) and hooks
    (``check_valid_op``, ``check_additional``); the traversal logic in
    ``check``/``_check_graph_module`` is final and enforced by the metaclass.
    """

    dialect = "ATEN"

    def allowed_builtin_ops(self) -> List:
        # Python-level operators/functions permitted as call_function targets.
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            math.ceil,
            math.floor,
        ]

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        # Operator classes permitted as call_function targets.
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        # Attribute types permitted as get_attr targets.
        return (torch.fx.GraphModule,)

    def check_valid_op(self, op):
        # Hook for subclasses to impose extra per-op restrictions.
        pass

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """
        pass

    @final
    def check(self, ep: ExportedProgram) -> None:
        # Full verification: graph structure first, then signature consistency.
        self._check_graph_module(ep.graph_module)
        _verify_exported_program_signature(ep)

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_getattr_types()
            # `object` would allow everything, defeating the check.
            assert not any(t is object for t in ret)
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> List:
                ret = self.allowed_builtin_ops()
                # NOTE(review): the comprehension variable shadows the outer
                # ``op``; the assertion checks the allow-list entries, not the
                # op under test.
                assert all(inspect.isbuiltin(op) for op in ret)
                return ret

            def _allowed_op_types() -> Tuple[Type[Any], ...]:
                ret = self.allowed_op_types()
                assert not any(t is object for t in ret)
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (
                torch.autograd.grad_mode.set_grad_enabled,
                torch.sym_int,
                torch.sym_ite,
                torch.sym_max,
                torch.sym_min,
                torch.sym_not,
                torch.sym_sqrt,
                # TODO (tmanlaibaatar)
                # Predispatch export is able to contain autograd ops.
                # These will be modeled as HOO later
                torch._C._set_grad_enabled

            )

            if not isinstance(op, _allowed_op_types()):
                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                    # NOTE(review): missing "\n" between the builtin-ops and
                    # torch-functions lines; the two lists run together in
                    # the rendered message.
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops functional
                if not is_functional(op):
                    raise SpecViolationError(
                        f"operator '{op}' is not functional"
                    )
            self.check_valid_op(op)

        # Verify every nested GraphModule, not just the root.
        for mod in gm.modules():
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module is not valid: got a class '{node.target}' ",
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr_recursive(mod, node.target)
                    if isinstance(attr, torch.nn.Module):
                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)
                        # Duck-typed check for executorch-style lowered
                        # modules (matched by class name, not import).
                        if type(attr).__name__ == "LoweredBackendModule":
                            if _is_type("backend_id", str) \
                                    and _is_type("processed_bytes", bytes) \
                                    and _is_type("compile_specs", list) \
                                    and hasattr(attr, "original_module"):
                                continue
                            else:
                                backend_id = getattr(attr, "backend_id", None)
                                processed_bytes = getattr(attr, "processed_bytes", None)
                                compile_specs = getattr(attr, "compile_specs", None)
                                raise SpecViolationError(
                                    f"Invalid get_attr type {type(attr)}. \n"
                                    f"LoweredBackendModule fields: "
                                    f"backend_id(str) : {type(backend_id)}, "
                                    f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                    f"compile_specs(list) : {type(compile_specs)}"
                                )

                    if not isinstance(attr, _allowed_getattr_types()):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types()}"
                        )


                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)
252
+
253
+
254
def _verify_exported_program_signature(exported_program) -> None:
    """Check that an ExportedProgram's graph signature is consistent with its graph.

    Validates, in order: placeholder count/names against ``input_specs``,
    per-kind input requirements (parameters, buffers, constant tensors,
    custom objects, tokens), output count against ``output_specs``, and the
    layout of the output node's args as [tokens, mutations, user outputs].

    Raises:
        SpecViolationError: on any mismatch between graph and signature.
    """
    gs = exported_program.graph_signature

    # Check every node in the signature exists in the graph
    input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]

    if len(input_node_names) != len(gs.input_specs):
        # Bug fix: report len(gs.input_specs) — the quantity actually
        # compared above — instead of len(gs.user_inputs).
        raise SpecViolationError(
            f"Number of graph inputs ({len(input_node_names)}) "
            f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
        )

    for input_spec, node in zip(gs.input_specs, input_node_names):
        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
            if input_spec.arg.name != node:
                raise SpecViolationError(
                    f"Input spec name {input_spec.arg.name} does not match node name {node}"
                )

        if input_spec.kind == InputKind.USER_INPUT:
            continue

        elif input_spec.kind == InputKind.PARAMETER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            param = input_spec.target
            if param not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Parameter {param} is not in the state dict."
                )

            if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                raise SpecViolationError(
                    f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
                )

        elif input_spec.kind == InputKind.BUFFER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            buffer = input_spec.target
            if input_spec.persistent is None:
                raise SpecViolationError(
                    f"Buffer {buffer} is missing a persistence flag"
                )

            # Persistent buffers live in the state dict; non-persistent ones
            # must not.
            if input_spec.persistent is True and buffer not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Buffer {buffer} is not in the state dict."
                )

            if input_spec.persistent is False and buffer in exported_program.state_dict:
                raise SpecViolationError(
                    f"Non-persistent buffer {buffer} is in the state dict, it should not be."
                )
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            tensor_const = input_spec.target
            if tensor_const not in exported_program.constants:
                raise SpecViolationError(
                    f"Constant tensor {tensor_const} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.CUSTOM_OBJ:
            if not isinstance(input_spec.arg, CustomObjArgument):
                raise SpecViolationError(
                    f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            custom_obj = input_spec.target
            if custom_obj not in exported_program.constants:
                raise SpecViolationError(
                    f"Custom object {custom_obj} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.TOKEN:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
        else:
            raise SpecViolationError(
                f"Unknown InputKind {input_spec.kind}."
            )

    # Check outputs
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_nodes = [
        arg.name if isinstance(arg, torch.fx.Node) else arg
        for arg in output_node.args[0]
    ]

    if len(output_nodes) != len(gs.output_specs):
        # Bug fix: "Than" -> "than" (mid-sentence capitalization in the
        # rendered error message).
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
        )

    # Output layout: [tokens][buffer/user-input mutations][user outputs].
    num_tokens = len(gs.output_tokens)
    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
    mutate_nodes: List[str] = output_nodes[num_tokens:end]
    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

    for mutation_node in mutate_nodes:
        if mutation_node in gs.buffers_to_mutate:
            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
                raise SpecViolationError(
                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                    f"Buffer nodes available: {gs.buffers} \n"
                )
        elif mutation_node in gs.user_inputs_to_mutate:
            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                raise SpecViolationError(
                    f"User input output {mutation_node} does not point to a user input that exists. \n"
                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
                    f"User input nodes available: {gs.user_inputs} \n")
        else:
            raise SpecViolationError(
                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
            )

    # User outputs must appear in exactly the signature's order.
    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_output list: {gs.user_outputs}. "
            )
411
+
412
+
413
def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
    """Look up the registered Verifier class for ``dialect``.

    The core "ATEN" dialect may legitimately be absent (returns None);
    any other dialect must have been registered, so a missing entry
    raises KeyError.
    """
    registry = _VerifierMeta._registry
    if dialect == "ATEN":
        return registry.get(dialect)
    return registry[dialect]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ from queue import Empty as EmptyQueue, Queue
4
+
5
+ from torch._lazy.device_context import get_device_context
6
+
7
+
8
class ClosureHandler:
    """Executes step closures synchronously, in submission order."""

    def __init__(self):
        pass

    def run(self, closure):
        """Run closure function

        Args:
            closure: callable function to run
        """
        closure()

    def __call__(self, closures):
        # Dispatch every queued closure, one after another.
        for fn in closures:
            self.run(fn)
23
+
24
+
25
class AsyncClosureHandler(ClosureHandler):
    """Handler for Asynchronous Step Closures
    Args:
        max_queue_size: The maximum length of the closure queue after which
            the training loop will block until closures are evaluated.
            By default, a reasonable limit of a maximum of 100 on the queue.
            This value can be set using the `LTC_MAX_ASYNC_QUEUE` environment
            variable.
    """

    def __init__(self, max_queue_size=100):
        super().__init__()
        # Bounded queue of pending closures; put() blocks when full, which
        # throttles the training loop until the worker catches up.
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        # Holds the first exception raised by a closure on the worker thread.
        self._closure_exception: Queue = Queue()
        self._closure_lock = threading.Lock()
        self._closure_event_loop_finished = threading.Event()
        # Lazily-started worker thread draining the queue (None until first run).
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start closure event loop if not started"""
        if self._closure_event_loop is None:

            def event_loop():
                # Run loop until closure event is set and closure queue is empty
                while True:
                    try:
                        # 3s timeout lets the worker exit when the queue drains.
                        closure = self._closure_queue.get(block=True, timeout=3)
                        closure()
                        self._closure_queue.task_done()
                    except EmptyQueue:
                        # Re-check emptiness under the lock so a concurrent
                        # run() cannot enqueue between the timeout and exit.
                        with self._closure_lock:
                            if self._closure_queue.empty():
                                self._closure_event_loop_finished.set()
                                return
                    except Exception as e:
                        # Park the error; the next run() call re-raises it.
                        self._closure_exception.put(e)
                        return

            self._closure_event_loop = threading.Thread(target=event_loop)
            self._closure_event_loop.start()

    def run(self, closure):
        """Enqueue *closure* for asynchronous execution, (re)starting the worker."""
        with self._closure_lock:
            self._closure_queue.put(closure, block=True)
            if (
                self._closure_event_loop is None
                or not self._closure_event_loop.is_alive()
            ):
                try:
                    # A dead worker with a parked exception means a previous
                    # closure failed: surface it instead of silently restarting.
                    e = self._closure_exception.get(block=False)
                    raise RuntimeError(
                        "Cannot run asynchronous closure due to previously raised exception"
                    ) from e
                except EmptyQueue:
                    # No pending error: spin up a fresh worker thread.
                    self._closure_event_loop = None
                    self.start_event_loop()
83
+
84
+
85
def add_step_closure(closure, args=(), run_async=False):
    """Queue *closure* to run at the end of the current step.

    Reporting on intermediary tensors (printing, tensorboard, ...) from many
    points in model code normally forces many separate executions.  Queuing
    the work as a step closure instead guarantees it runs after the barrier,
    once every live tensor — including those captured by *args* — has been
    materialized to device data, so a single execution serves all queued
    closures.  Closures run sequentially in queue order.  Even so, it is
    advisable to throttle reporting to once every N steps.

    Args:
        closure (callable): The function to be called.
        args (tuple): The arguments to be passed to the closure.
        run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    attr_name = "async_step_closures" if run_async else "step_closures"
    pending = getattr(devctx, attr_name, None)
    if pending is None:
        pending = []
        setattr(devctx, attr_name, pending)
    # Bind args as a default argument so the values captured now are the
    # ones the closure receives when it eventually runs.
    pending.append(lambda a=args: closure(*a))
113
+
114
+
115
def run_step_closures():
    """Execute and clear every queued step closure on the current device context.

    Asynchronous closures are dispatched through an ``AsyncClosureHandler``;
    synchronous ones run inline via a ``ClosureHandler``.  Both handlers are
    created lazily and cached on the device context.  Returns the context.
    """
    devctx = get_device_context()

    pending_async = getattr(devctx, "async_step_closures", None)
    if pending_async is not None:
        # Reset the queue before dispatch so closures queued while these run
        # land in a fresh list.
        devctx.async_step_closures = []
        handler = getattr(devctx, "async_closure_handler", None)
        if handler is None:
            handler = AsyncClosureHandler()
            devctx.async_closure_handler = handler
        handler(pending_async)

    pending_sync = getattr(devctx, "step_closures", None)
    if pending_sync is not None:
        devctx.step_closures = []
        handler = getattr(devctx, "closure_handler", None)
        if handler is None:
            handler = ClosureHandler()
            devctx.closure_handler = handler
        handler(pending_sync)
    return devctx
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch._C._lazy
2
+ import torch._C._lazy_ts_backend
3
+
4
+
5
def get_tensors_ts_device_data_node(tensors):
    """Return tensor ids and eager tensors for DeviceData nodes in the
    IR for the passed in lazy tensors.

    Thin wrapper over the TorchScript-backend C binding.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)
13
+
14
+
15
def get_graph_hash(tensors):
    """Return the graph hash for the passed in lazy tensors.

    Delegates directly to the lazy-tensor C binding.
    """
    return torch._C._lazy._get_graph_hash(tensors)
18
+
19
+
20
def run_cached_graph(hash_str, graph_inputs):
    """Running the cached computation graph with the given inputs

    ``hash_str`` identifies a previously compiled graph (see
    ``get_graph_hash``); ``graph_inputs`` are its runtime inputs.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch._C._lazy
2
+
3
+
4
def render_ir_graph(tensors):
    """Return a text dump of the LTC IR graph in dot format for the tensors.
    The text can be processed by tools like dot to be rendered in pdf,png etc.

    Thin wrapper over the lazy-tensor C binding.
    """
    return torch._C._lazy._get_tensors_dot(tensors)
8
+
9
+
10
def dump_ir(tensors, ir_format):
    """Return a dump of the tensors in the specified format.

    Valid formats are:
      - "text": the LTC IR
      - "backend": the active backend's IR

    Raises:
        RuntimeError: if *ir_format* is not one of the supported formats.
    """
    if ir_format == "text":
        return torch._C._lazy._get_tensors_text(tensors)
    if ir_format == "backend":
        return torch._C._lazy._get_tensors_backend(tensors)
    raise RuntimeError(f"Unrecognized IR format: {ir_format}")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch._C._lazy_ts_backend
2
+
3
+
4
def init():
    """Initializes the lazy Torchscript backend.

    Delegates to the C binding; must run before using the TS lazy backend.
    """
    torch._C._lazy_ts_backend._init()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc ADDED
Binary file (208 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc ADDED
Binary file (29 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .autocast_mode import autocast, custom_bwd, custom_fwd
2
+ from .common import amp_definitely_not_available
3
+ from .grad_scaler import GradScaler
4
+
5
+ __all__ = [
6
+ "amp_definitely_not_available",
7
+ "autocast",
8
+ "custom_bwd",
9
+ "custom_fwd",
10
+ "GradScaler",
11
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (525 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (244 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc ADDED
Binary file (28.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc ADDED
Binary file (72.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc ADDED
Binary file (52.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc ADDED
Binary file (549 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc ADDED
Binary file (2.46 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
2
+ op_mod, op_gt, op_lt, op_neq, op_eq
3
+ from torch.fx.tensor_type import TensorType, Dyn
4
+
5
+
6
class Constraint:
    # Abstract marker base class for every constraint kind in this module.
    pass
8
+
9
+
10
class Conj(Constraint):
    """Conjunction (logical AND) of a list of constraints."""

    def __init__(self, conjuncts):
        """
        :param conjuncts: Conjunction of constraints
        """
        # NOTE: the attribute keeps the historical misspelling "conjucts"
        # because external code reads it directly.
        self.conjucts = conjuncts

    def __eq__(self, other):
        # Fixed: the original compared self.conjucts == other.conjucts
        # twice, joined by "and" — a redundant duplicate.
        if isinstance(other, Conj):
            return self.conjucts == other.conjucts
        else:
            return False

    def __repr__(self):
        return f'And({self.conjucts})'
25
+
26
+
27
class Disj(Constraint):
    """Disjunction (logical OR) of a list of constraints."""

    def __init__(self, disjuncts):
        """
        :param disjuncts: Disjunction of constraints
        """
        self.disjuncts = disjuncts

    def __eq__(self, other):
        # Fixed: the original compared self.disjuncts == other.disjuncts
        # twice, joined by "and" — a redundant duplicate.
        if isinstance(other, Disj):
            return self.disjuncts == other.disjuncts
        else:
            return False

    def __repr__(self):
        return f'Or({self.disjuncts})'
42
+
43
+
44
class Prod(Constraint):
    """Product of a list of dimensions."""

    def __init__(self, products):
        """
        :param products: lists of dimensions to multiply
        """
        self.products = products

    def __eq__(self, other):
        # Fixed: the original compared self.products == other.products
        # twice, joined by "and" — a redundant duplicate.
        if isinstance(other, Prod):
            return self.products == other.products
        else:
            return False

    def __repr__(self):
        return f'Product({self.products})'
59
+
60
+
61
class T(Constraint):
    """
    True
    """

    def __init__(self):
        pass

    def __eq__(self, other):
        # Every T instance is interchangeable with every other.
        return isinstance(other, T)

    def __repr__(self):
        return 'True'
73
+
74
class F(Constraint):
    """
    False
    """

    def __init__(self):
        pass

    def __eq__(self, other):
        # Every F instance is interchangeable with every other.
        return isinstance(other, F)

    def __repr__(self):
        return 'False'
86
+
87
+
88
class BinaryConstraint(Constraint):
    """
    Represents all binary operations
    """

    def __init__(self, lhs, rhs, op):
        """
        :param lhs: lhs of the constraint
        :param rhs: rhs of the constraint
        :param op: string representing the operation
        """
        self.lhs = lhs
        self.rhs = rhs
        self.op = op

    def __eq__(self, other):
        # Structural equality: same operands and same operator.
        if not isinstance(other, BinaryConstraint):
            return False
        return (self.lhs, self.rhs, self.op) == (other.lhs, other.rhs, other.op)

    def __repr__(self):
        return f'({self.lhs} {self.op} {self.rhs})'
110
+
111
+
112
class BinConstraintT(BinaryConstraint):
    """
    Binary constraints about tensors
    """
    def __init__(self, lhs, rhs, op):
        # Both sides must be tensor-like: a tensor variable (TVar), a
        # concrete TensorType, an int, or the dynamic type Dyn.
        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
            (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Equality is the structural equality inherited from BinaryConstraint.
        return super().__eq__(other)
123
+
124
+
125
class BinConstraintD(BinaryConstraint):
    """
    Binary constraints about dimensions
    """
    def __init__(self, lhs, rhs, op):
        # Each side must be an algebraic expression over dimensions, a
        # dimension itself (DVar/int/Dyn), or a boolean expression.
        assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
        assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)

        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Equality is the structural equality inherited from BinaryConstraint.
        return super().__eq__(other)
137
+
138
+
139
+
140
class TGreatestUpperBound(Constraint):
    """
    Greatest Upper bound for tensors with dynamic type
    """
    def __init__(self, res, rhs1, rhs2):
        """
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
        """
        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        # ⊔* denotes the tensor-level least upper bound operator.
        return f'{self.res} = {self.rhs1}⊔*{self.rhs2}'

    def __eq__(self, other):
        if isinstance(other, TGreatestUpperBound):
            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
        else:
            return False
162
+
163
+
164
class DGreatestUpperBound(Constraint):
    """
    Greatest Upper bound for dimensions
    """
    def __init__(self, res, rhs1, rhs2):
        """
        :param res: Dimension variable to store the result
        :param rhs1: dimension variable 1
        :param rhs2: dimension variable 2
        """
        # All three operands must be dimensions (DVar, int, or Dyn).
        assert is_dim(res)
        assert is_dim(rhs1)
        assert is_dim(rhs2)

        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        # ⊔ denotes the dimension-level least upper bound operator.
        return f'{self.res} = {self.rhs1}⊔{self.rhs2}'

    def __eq__(self, other):
        if isinstance(other, DGreatestUpperBound):
            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
        else:
            return False
190
+
191
+
192
class CanReshape(Constraint):
    """
    can_reshape constraint
    """

    def __init__(self, src, target):
        """
        :param src: tensor variable
        :param target: tensor
        """
        self.src = src
        self.target = target

    def __repr__(self):
        return f'can-reshape({self.src}, {self.target})'

    def __eq__(self, other):
        # Equal iff both source and target match on another CanReshape.
        if not isinstance(other, CanReshape):
            return False
        return (self.src, self.target) == (other.src, other.target)
212
+
213
+
214
class IndexSelect(Constraint):
    # Models torch.index_select: one dimension of the input is replaced.

    def __init__(self, tensor_size, input_var, dim_replace, index, output):
        """
        Args:
            input_var: input to index_select
            tensor_size: tensor size we are considering
            dim_replace: the dimension of the output at "index"
            index: location of the dimensions to replace in the input
            output: variable to store the result
        """
        # Inputs/outputs are tensor variables; the replacement dimension
        # is a dimension variable or the dynamic type Dyn.
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
        assert isinstance(index, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.dim_replace = dim_replace
        self.index = index
        self.output = output

    def __repr__(self):

        return f' {self.output} = ' \
            f'IndexSelect({self.input_var}, ' \
            f'tensor_size: {self.tensor_size}, ' \
            f'{self.dim_replace}, ' \
            f'{self.index})'

    def __eq__(self, other):
        # Structural equality over all five fields.
        if isinstance(other, IndexSelect):
            return self.tensor_size == other.tensor_size and \
                self.dim_replace == other.dim_replace and \
                self.index == other.index and \
                self.output == other.output and \
                self.input_var == other.input_var
        else:
            return False
253
+
254
+
255
class Transpose(Constraint):
    # Models torch.transpose: two dimensions of the input are swapped.

    def __init__(self, tensor_size, input_var, index1, index2, output):
        """
        Args:
            tensor_size: current tensor size
            input_var: variable to hold input
            index1: dimension 1
            index2: dimension 2
            output: output that stores result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(index1, int)
        assert isinstance(index2, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.index1 = index1
        self.index2 = index2
        self.output = output

    def __repr__(self):

        return f' {self.output} = ' \
            f'Transpose({self.input_var}, ' \
            f'tensor_size: {self.tensor_size}, ' \
            f'{self.index1}, ' \
            f'{self.index2})'

    def __eq__(self, other):
        # Structural equality over all five fields.
        if isinstance(other, Transpose):
            return self.tensor_size == other.tensor_size and \
                self.index1 == other.index1 and \
                self.index2 == other.index2 and \
                self.output == other.output and \
                self.input_var == other.input_var
        else:
            return False
294
+
295
+
296
class GetItem(Constraint):
    # Models indexing a single dimension out of a tensor of known rank.

    def __init__(self, tensor_size, index, res, input_var):
        """
        Constraint for getting item given a tensor size
        :param tensor_size: actual number
        :param index: actual number representing the index
        :param res: dimension variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        # The result of a scalar index is a dimension, hence a DVar.
        assert isinstance(res, DVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index = index
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'

    def __eq__(self, other):
        # Structural equality over all four fields.
        if isinstance(other, GetItem):
            return self.res == other.res and \
                self.tensor_size == other.tensor_size and \
                self.index == other.index and \
                self.input_var == other.input_var
        else:
            return False
324
+
325
class GetItemTensor(Constraint):
    # Like GetItem, but tuple indexing yields a tensor rather than a dimension.

    def __init__(self, tensor_size, index_tuple, res, input_var):
        """
        Constraint for getting item given a tensor size
        However, when the argument is a tuple, we will
        expect a tensor
        :param tensor_size: actual number representing the rank
        :param index_tuple: tuple for indexing
        :param res: tensor variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        # Tuple indexing produces a tensor, hence the result is a TVar.
        assert isinstance(res, TVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index_tuple = index_tuple
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'

    def __eq__(self, other):
        # Structural equality over all four fields.
        if isinstance(other, GetItemTensor):
            return self.res == other.res and \
                self.tensor_size == other.tensor_size and \
                self.index_tuple == other.index_tuple and \
                self.input_var == other.input_var
        else:
            return False
355
+
356
class CalcConv(Constraint):
    # Models the output-shape computation of a 2-D convolution.

    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param conv_result: the convolution result
        :param input_var: input to convolution
        :param c_out: output channel type
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: variables tying input rank to the conv
        """
        self.conv_result = conv_result
        self.input_var = input_var
        self.c_out = c_out
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.conv_result} =' \
            f' calc-conv({self.input_var},' \
            f' {self.c_out}, {self.kernel}, ' \
            f'{self.padding}, {self.stride},' \
            f' {self.dilation})'

    def __eq__(self, other):
        # Structural equality over every stored field.
        if isinstance(other, CalcConv):
            return self.conv_result == other.conv_result and self.input_var == other.input_var and \
                self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
                and self.stride == other.stride and self.dilation == other.dilation \
                and self.matching_constraint == other.matching_constraint
        else:
            return False
389
+
390
+
391
class CalcMaxPool(Constraint):
    # Models the output-shape computation of a 2-D max-pool.

    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param maxpool_result: the result of maxpool
        :param input_var: input to convolution
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: variables tying input rank to the pool
        """
        self.maxpool_result = maxpool_result
        self.input_var = input_var
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.maxpool_result} =' \
            f' calc-maxpool({self.input_var},' \
            f' {self.kernel}, ' \
            f'{self.padding}, {self.stride},' \
            f' {self.dilation})'

    def __eq__(self, other):
        # Structural equality over every stored field.
        if isinstance(other, CalcMaxPool):
            return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
                and self.kernel == other.kernel and self.padding == other.padding \
                and self.stride == other.stride and self.dilation == other.dilation \
                and self.matching_constraint == other.matching_constraint
        else:
            return False
422
+
423
+
424
class ApplyBroadcasting(Constraint):
    """Ties two input tensor variables to their broadcast result variables."""

    def __init__(self, res1, res2, input1, input2):
        """
        :param res1: resulting tensor 1
        :param res2: resulting tensor 2
        :param input1: tensor variable 1
        :param input2: tensor variable 2
        """
        self.res1 = res1
        self.res2 = res2
        self.input1 = input1
        self.input2 = input2

    def __eq__(self, other):
        # Structural equality over both results and both inputs.
        if not isinstance(other, ApplyBroadcasting):
            return False
        return (self.res1, self.res2, self.input1, self.input2) == \
            (other.res1, other.res2, other.input1, other.input2)

    def __repr__(self):
        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'
448
+
449
+
450
class CalcProduct(Constraint):
    """
    Given correct dimensions, calculate the product for flatten accounting for Dyn
    """
    def __init__(self, start, end, flattened, dims_to_flatten):
        """
        :param start: start index
        :param end: end index
        :param flattened: variable to store the product
        :param dims_to_flatten: the type which we will flatten
        """
        # start/end are concrete indices into dims_to_flatten; the result
        # of flattening is a tensor, hence flattened is a TVar.
        assert isinstance(dims_to_flatten, list)
        assert isinstance(flattened, TVar)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self.start = start
        self.end = end
        self.dims_to_flatten = dims_to_flatten
        self.flattened = flattened

    def __eq__(self, other):
        # Structural equality over all four fields.
        if isinstance(other, CalcProduct):
            return self.start == other.start and self.end == other.end and \
                self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened

        else:
            return False

    def __repr__(self):
        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
481
+
482
+
483
class TVar:
    """
    Tensor variable with no tensor constructor
    """

    def __init__(self, tvar):
        """
        :param tvar: tensor variable
        """
        self.tvar = tvar

    def __repr__(self):
        return f'TV({self.tvar})'

    def __eq__(self, other):
        # Two tensor variables are equal iff their identifiers match.
        return isinstance(other, TVar) and self.tvar == other.tvar
501
+
502
+
503
class DVar:
    """
    Dimension variable
    """

    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'DV({self.c})'

    def __eq__(self, other):
        # Two dimension variables are equal iff their identifiers match.
        return isinstance(other, DVar) and self.c == other.c
521
+
522
+
523
class BVar:
    """
    Boolean variable
    """

    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'BV({self.c})'

    def __eq__(self, other):
        # Two boolean variables are equal iff their identifiers match.
        return isinstance(other, BVar) and self.c == other.c
541
+
542
+
543
def is_algebraic_expression(constraint):
    """True for arithmetic dimension constraints (+, -, /, *, %) and products."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in (op_add, op_sub, op_div, op_mul, op_mod)
    return isinstance(constraint, Prod)
548
+
549
+
550
def is_bool_expr(constraint):
    """True for comparison dimension constraints (>, <, !=, ==) and boolean forms."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in (op_gt, op_lt, op_neq, op_eq)
    return isinstance(constraint, (BVar, Conj, Disj))
555
+
556
def is_dim(d):
    """A dimension is a DVar, a concrete int, or the dynamic type Dyn."""
    if isinstance(d, (DVar, int)):
        return True
    return d == Dyn
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py ADDED
@@ -0,0 +1,1279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import operator
3
+ import warnings
4
+ from typing import Callable, Dict, Iterable
5
+
6
+ from torch.fx._symbolic_trace import _assert_is_none
7
+ from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
8
+ Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
9
+ TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
10
+ from torch.fx.experimental.migrate_gradual_types.operation import \
11
+ op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
12
+ from torch.fx.node import Target, Node
13
+ from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
14
+ gen_bvar
15
+
16
+ from torch.fx.tensor_type import Dyn, TensorType
17
+ from torch.nn.modules.conv import Conv2d
18
+ from torch.nn.modules.batchnorm import BatchNorm2d
19
+
20
+ _INFERENCE_RULES: Dict[Target, Callable] = {}
21
+
22
+ MAX_TENSOR_RANK = 4
23
+
24
def register_inference_rule(call_target):
    """Decorator factory: register the wrapped function as the inference
    rule for *call_target*, enforcing one rule per target."""
    def register(fn):
        if call_target not in _INFERENCE_RULES:
            _INFERENCE_RULES[call_target] = fn
            return fn
        # Duplicate registration is a programming error.
        raise RuntimeError(f'Inference rule already registered for {call_target}!')
    return register
31
+
32
+
33
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
    """Generate constraints modelling flatten(start_dim, end_dim) on an
    n-dimensional tensor.

    Returns a conjunction of: the input matching an n-rank tensor of fresh
    dimension variables, a CalcProduct collapsing [start_dim, end_dim) into
    *flattened*, and naturalness constraints on the fresh dimensions —
    together with the updated variable counter.
    """
    d, counter = gen_tensor_dims(n, counter)
    c1 = BinConstraintT(input, TensorType(d), op_eq)
    # Normalize the dim arguments to a half-open [start_dim, end_dim) range
    # over positive indices (negative end_dim counts from the back).
    start_dim = n if start_dim == -1 else abs(start_dim)
    end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1
    c2 = CalcProduct(start_dim, end_dim, flattened, d)
    nat_constraints = gen_nat_constraints(d)
    return Conj([c1, c2, *nat_constraints]), counter
41
+
42
+
43
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
    """
    If the attribute is "device" then the tensor shape is preserved
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], str)
    output, counter = gen_tvar(counter)
    symbols[n] = output

    input = symbols[n.args[0]]
    attr = n.args[1]

    if attr != 'device':
        raise NotImplementedError('Not yet implemented')
    # Reading .device does not change the tensor, so the output's type
    # is constrained to equal the input's.
    return [BinConstraintT(input, output, op_eq)], counter
60
+
61
@register_inference_rule(torch.bmm)
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints that match the input to a size 3 tensor
    and switch the dimensions according to the rules
    of batch multiplication
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    # Fresh variable standing for the bmm result.
    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    # bmm operates on rank-3 tensors: (batch, i, j) x (batch, j, k).
    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    # Case 1: both inputs fully dynamic -> output fully dynamic.
    inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_output, Dyn, op_eq)])

    # Case 2: first input dynamic; the middle output dimension is unknown.
    input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])

    # Case 3: second input dynamic; the last output dimension is unknown.
    input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])

    # The two batch dimensions must be consistent with each other.
    consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]

    batch_size, counter = gen_dvar(counter)

    # Case 4: both inputs are rank-3 tensors; the output batch dimension is
    # the greatest upper bound of the two input batch dimensions.
    inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                               BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                               BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
                               *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])

    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
102
+
103
+
104
@register_inference_rule("index_select")
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    We constrain the second argument to a vector or Dyn.
    The output replaces the input with the shape of the vector
    at the position given by the index (first argument)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    # One dimension variable: the index argument must be a vector.
    dims, counter = gen_tensor_dims(1, counter)

    # equality constraint
    is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)

    # Enumerate input ranks 1..MAX_TENSOR_RANK; the selected dimension is
    # replaced by the index vector's length (or by Dyn when it is unknown).
    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
                                for i in range(MAX_TENSOR_RANK)])])
    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
                             for i in range(MAX_TENSOR_RANK)])])

    return [Disj([c2, c3])], counter
133
+
134
+
135
@register_inference_rule("expand")
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the exact constraints as we do for tensor additions but we constraint
    the rank of this expression to be equal to len(n.args[1:]) so that only
    those cases get considered for the output
    """
    assert isinstance(n.args[0], Node)

    # define the output for expand
    expand, counter = gen_tvar(counter)
    symbols[n] = expand

    # since we do not have two nodes here, we will construct an argument variable
    e1 = symbols[n.args[0]]
    e2, counter = gen_tvar(counter)

    # Node-valued size arguments must map to natural-number dimension variables.
    e2_nat_constraints = []
    for arg in n.args[1:]:
        assert isinstance(arg, (Node, int))
        if isinstance(arg, Node):
            assert isinstance(symbols[arg], DVar)
            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))

    # Pin the synthetic operand e2 to the requested target shape.
    e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)

    # NOTE: this deliberately rebinds the `constraints` parameter with the
    # broadcasting constraints between the input and the synthetic shape.
    constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)

    # constraint the output size
    dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
    nat_constraints = gen_nat_constraints(dims)
    c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
    constraints += c

    return constraints, counter
170
+
171
+
172
@register_inference_rule(torch.nn.functional.gelu)
@register_inference_rule(torch.nn.functional.dropout)
@register_inference_rule(torch.nn.functional.softmax)
@register_inference_rule("detach")
@register_inference_rule("to")
@register_inference_rule("int")
@register_inference_rule("long")
@register_inference_rule("contiguous")
@register_inference_rule(torch.ones)
@register_inference_rule(torch.zeros)
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output

    if isinstance(n.args[0], Node):
        input = symbols[n.args[0]]
        if isinstance(input, TVar):
            # Shape-preserving op: the output type equals the input type.
            return [BinConstraintT(input, output, op_eq)], counter

        # then we have dimension variables
        else:
            # e.g. torch.ones(d1, d2, ...): every positional arg is a
            # dimension variable and together they form the output shape.
            for arg in n.args:
                assert isinstance(symbols[arg], DVar)
            my_size = [symbols[arg] for arg in n.args]
            return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    elif isinstance(n.args[0], tuple):
        # then the tuple is the size
        assert len(n.args[0]) <= 4
        my_size = [symbols[arg] for arg in n.args[0]]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
    else:
        raise NotImplementedError('Method not yet implemented')
208
+
209
+
210
@register_inference_rule("transpose")
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Can be considered as a sequence of two index selects, so we generate constraints accordingly
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)

    transposed, counter = gen_tvar(counter)
    symbols[n] = transposed

    source = symbols[n.args[0]]
    assert isinstance(source, TVar)

    # Case 1: a fully dynamic input yields a fully dynamic output.
    both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(transposed, Dyn, op_eq)])

    # Case 2: the input has some rank in 1..MAX_TENSOR_RANK and the two
    # given dimensions are swapped.
    swap_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        swap_cases.append(Transpose(rank, source, n.args[1], n.args[2], transposed))

    return [Disj([both_dyn, Disj(swap_cases)])], counter
232
+
233
+
234
@register_inference_rule("type_as")
def type_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    converted, counter = gen_tvar(counter)
    symbols[n] = converted

    source_var = symbols[n.args[0]]
    target_var = symbols[n.args[1]]

    assert isinstance(source_var, TVar)
    assert isinstance(target_var, TVar)

    # The source must be consistent with the target, and the result
    # takes on the target's type.
    consistency = BinConstraintT(source_var, target_var, op_consistency)
    output_eq = BinConstraintT(converted, target_var, op_eq)
    return [consistency, output_eq], counter
253
+
254
@register_inference_rule("masked_fill_")
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to addition. For now we implement the constraints when
    the argument is a boolean tensor. There is also a case for when
    it is a condition. We will leave this out for now.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    # Look up the operand variables; only the tensor/tensor case is handled.
    target = symbols[n.args[0]]
    mask = symbols[n.args[1]]

    if not (isinstance(target, TVar) and isinstance(mask, TVar)):
        raise NotImplementedError('Not yet implemented')

    result, counter = gen_tvar(counter)
    symbols[n] = result
    # masked_fill_ broadcasts the mask against the input tensor.
    return gen_broadcasting_constraints(target, mask, symbols, counter, result)
277
+
278
+
279
@register_inference_rule(torch.nn.functional.embedding)
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
    """Functional embedding: constrain the weight to be a rank-2 tensor and
    delegate to gen_embedding_rules with its second dimension as the
    embedding dimension."""
    assert isinstance(n.args[0], Node)

    embedding_dim_weights = symbols[n.args[1]]

    # will treat this as a static shape. So we will not use matching.
    weight_dims, counter = gen_tensor_dims(2, counter)
    equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
    embedding_dim = weight_dims[1]
    constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
    return [equality_constraint] + constraints, counter
291
+
292
+
293
@register_inference_rule(torch.nn.modules.sparse.Embedding)
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension
    """
    assert isinstance(n.args[0], Node)
    # The module instance already carries the embedding width; defer to the
    # shared generator used by the functional form as well.
    dim = module_instance.embedding_dim
    return gen_embedding_rules(n, symbols, dim, counter)
300
+
301
+
302
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """Shared embedding constraints: the output is the input shape with
    ``embedding_dim`` appended, or Dyn when the input is Dyn.

    Returns a (constraints, counter) pair.
    """
    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    # Enumerate input ranks 1..MAX_TENSOR_RANK-1 (the output gains one
    # trailing dimension, so it stays within MAX_TENSOR_RANK).
    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
                           BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
325
+
326
+
327
@register_inference_rule(torch.tensor)
def tensor_inference_rule(n: Node, symbols, constraints, counter):
    """
    If the tensor is a scalar, we will skip it since we
    do not support scalars yet. We will add support in the future
    if it's needed. For our examples so far, scalars are not needed.
    """
    # No constraints are generated; the counter is passed through unchanged.
    return [], counter
335
+
336
+
337
@register_inference_rule("reshape")
@register_inference_rule("view")
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to reshape but with an extra condition on the strides
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]
    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]  # target shape
    t2_type = []
    num_constraints = []

    for t in t2:
        if t == -1:
            # -1 means "infer this dimension": use a fresh dimension
            # variable that must resolve to a static size.
            var, counter = gen_dvar(counter)
            t2_type.append(var)
            num_constraints.append(BinConstraintD(var, Dyn, op_neq))

        else:
            num_constraints.append(BinConstraintD(t, Dyn, op_neq))
            t2_type.append(t)

    t2_type = TensorType(t2_type)  # type: ignore[assignment]

    # The result has the target shape, and the source must be reshapeable to it.
    c1 = BinConstraintT(my_view, t2_type, op_eq)
    c2 = CanReshape(src_var, t2_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view

    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
374
+
375
+
376
@register_inference_rule("size")
def size_inference_rule(n: Node, symbols, constraints, counter):
    """
    The constraint is just lhs = rhs.
    Ex: size = input_ids.size()
    """

    if len(n.args) == 1:
        # x.size() with no index: the result mirrors the input's full type.
        # generate the new variable
        size, counter = gen_tvar(counter)
        symbols[n] = size
        input = symbols[n.args[0]]
        c = BinConstraintT(input, size, op_eq)
        return [c], counter

    elif len(n.args) == 2:
        # x.size(i): the result is a single (natural-number) dimension.
        # TODO: review this rule; should input = dyn; output = dyn be included here?
        if isinstance(n.args[1], int):
            # generate the new variable
            size_index, counter = gen_dvar(counter)
            symbols[n] = size_index
            input = symbols[n.args[0]]
            # One GetItem per candidate input rank; the index picks a dim.
            c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
            c3 = BinConstraintD(0, size_index, op_leq)

            input_dyn = BinConstraintT(input, Dyn, op_eq)
            output_dyn = BinConstraintD(size_index, Dyn, op_eq)
            c1 = Conj([input_dyn, output_dyn])

            return [Disj([c1, Conj([Disj(c2), c3])])], counter

        else:
            raise NotImplementedError

    else:
        raise NotImplementedError
413
+
414
+
415
def range_check(i, n):
    """
    Checks if an index i is within range of a size n list
    Args:
        i: index
        n: list size

    Returns: Boolean
    """
    if i >= 0:
        return T() if i < n else F()
    else:
        # A negative index is valid when -n <= i, matching Python's
        # negative-indexing convention. (The previous check compared
        # against n instead of -n, which rejected every negative index
        # for a positive rank n.)
        return T() if i >= -n else F()
428
+
429
+
430
@register_inference_rule(torch.cumsum)
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Input and output shapes should be equal
    We should verify that the index is valid
    """
    assert isinstance(n.args[0], Node)
    # The dim argument may be positional or passed as a keyword.
    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(arg_1, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)
    c1 = Conj([input_dyn, output_dyn])
    c2 = []
    # For each candidate rank, input and output shapes match and the dim
    # index must be in range for that rank.
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
                          [range_check(arg_1, i)] + nat_constraints)

        c2.append(c_tensor_i)
    dyn_or_tensor = Disj([c1, Disj(c2)])
    return [dyn_or_tensor], counter
460
+
461
+
462
@register_inference_rule(_assert_is_none)
def assert_inference_rule(n: Node, symbols, constraints, counter):
    """Assertions produce no tensor value; require the node to be unused
    and emit no constraints."""
    assert len(n.users) == 0
    return [], counter
466
+
467
+
468
@register_inference_rule(operator.getitem)
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """Handle ``x[i]`` (an int index yields a dimension variable) and
    ``x[tuple]`` (a tuple index yields a tensor variable)."""
    assert isinstance(n.args[0], Node)

    # dimension output case
    if isinstance(n.args[1], int):
        # create and store the new dimension variable
        get_item_output, counter = gen_dvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        get_item_arg = symbols[n.args[0]]
        assert isinstance(get_item_arg, TVar)

        # if the input is dynamic, we accept any index and return
        # a dynamic dimension as output
        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
        output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
        c1 = Conj([input_dyn, output_dyn])

        # if the input is a tensor,
        # generate a getItem constraint which will be expanded based on the
        # tensor dimension.
        c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]

        # since the output is a dimension, we make sure it's a natural number
        # added as a conjunction to the disjunction of c2
        c3 = BinConstraintD(0, get_item_output, op_leq)
        return [Disj([c1, Conj([Disj(c2), c3])])], counter

    # tensor output case
    elif isinstance(n.args[1], tuple):
        # create and store the new tensor variable
        get_item_output, counter = gen_tvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        if n.args[0] in symbols:
            get_item_arg = symbols[n.args[0]]
            assert isinstance(get_item_arg, TVar)

            input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
            output_dyn = BinConstraintT(get_item_output, Dyn, op_eq)  # type: ignore[assignment]
            c1 = Conj([input_dyn, output_dyn])

            c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
                  for i in range(MAX_TENSOR_RANK)]
        else:
            # TODO: we should figure out why there is a key-error here.
            return [], counter

        return [Disj([c1, *c2])], counter

    else:
        raise RuntimeError('Method not yet implemented')
526
+
527
+
528
@register_inference_rule(operator.gt)
def gt_inference_rule(n: Node, symbols, constraints, counter):
    """Generate constraints for ``operator.gt``.

    Tensor/tensor operands are treated like broadcasting ops; dimension
    comparisons produce a boolean variable equated with an ``op_gt``
    constraint (used for flow analysis only).
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        elif isinstance(e1, TVar) and isinstance(e2, int):
            # then we made the wrong assumption about the argument being a tensor
            # so we should fix the assumption
            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')

            # Re-bind the argument to a fresh dimension variable.
            # (A stray no-op expression statement that re-read
            # ``symbols[n.args[0]]`` without using it was removed here.)
            new_e1, counter = gen_dvar(counter)
            symbols[n.args[0]] = new_e1

            gt_constraint = BinConstraintD(new_e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
585
+
586
+
587
@register_inference_rule(operator.eq)
def eq_inference_rule(n: Node, symbols, constraints, counter):
    """Generate constraints for ``operator.eq`` — mirrors the gt rule with
    ``op_eq`` in place of ``op_gt``."""
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            # Tensor/tensor comparison broadcasts like elementwise ops.
            eq_tensor, counter = gen_tvar(counter)
            symbols[n] = eq_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            eq_constraint = BinConstraintD(e1, e2, op_eq)

            my_eq, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            eq_constraint = BinConstraintD(e1, e2, op_eq)

            my_eq, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
            return [equality_constraint], counter
        else:
            raise NotImplementedError('Method not yet implemented')
    else:
        raise NotImplementedError('Method not yet implemented')
624
+
625
@register_inference_rule(operator.ne)
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates to inconsistent in gradual types.
    To prove inequality, we should prove that
    tensors are either different sizes or
    disagree on at least one dimension

    This is a WIP (works when the condition
    is false. We are working on making this operation work
    when the condition is true as well)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    # implementing for size 3 and 4
    if len(n.args[1]) == 3:

        assert isinstance(n.args[1][0], (Node, int))
        assert isinstance(n.args[1][1], (Node, int))
        assert isinstance(n.args[1][2], (Node, int))

        lhs = symbols[n.args[0]]

        # NOTE(review): four dimension variables are generated but only
        # three are used; kept as-is to preserve variable numbering.
        b, counter = gen_tensor_dims(4, counter)
        input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)

        # Literal ints compare directly; Node args use their dim variables.
        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b[0], op_neq)
        neq_2 = BinConstraintD(d2, b[1], op_neq)
        neq_3 = BinConstraintD(d3, b[2], op_neq)

        # dimensions inconsistent: both sides are static and differ
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])

        # we are covering size 3 and 4 only for now
        ne_constraint = Conj([input_is_size3, dims_inconsistent])

        # (a previously duplicated, unused gen_bvar call was removed here)
        my_ne, counter = gen_bvar(counter)
        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    elif len(n.args[1]) == 4:

        assert isinstance(n.args[1][0], (Node, int))
        assert isinstance(n.args[1][1], (Node, int))
        assert isinstance(n.args[1][2], (Node, int))
        assert isinstance(n.args[1][3], (Node, int))

        lhs = symbols[n.args[0]]

        b1, counter = gen_dvar(counter)
        b2, counter = gen_dvar(counter)
        b3, counter = gen_dvar(counter)
        b4, counter = gen_dvar(counter)

        input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)

        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
        d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b1, op_neq)
        neq_2 = BinConstraintD(d2, b2, op_neq)
        neq_3 = BinConstraintD(d3, b3, op_neq)
        neq_4 = BinConstraintD(d4, b4, op_neq)

        # dimensions inconsistent
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
        # BUGFIX: the fourth case previously tested b3 against Dyn instead
        # of b4, so inconsistency in the last dimension could be missed.
        dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq), neq_4])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])

        ne_constraint = Conj([input_is_size4, dims_inconsistent])

        my_ne, counter = gen_bvar(counter)

        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    else:
        raise NotImplementedError('Method not yet implemented')

    return [equality_constraint], counter
721
+
722
+
723
@register_inference_rule(operator.lt)
def lt_inference_rule(n: Node, symbols, constraints, counter):
    """Generate constraints for ``operator.lt`` — mirrors the gt rule with
    ``op_lt`` in place of ``op_gt``."""
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            # Tensor/tensor comparison broadcasts like elementwise ops.
            lt_tensor, counter = gen_tvar(counter)
            symbols[n] = lt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            lt_constraint = BinConstraintD(e1, e2, op_lt)

            my_lt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            lt_constraint = BinConstraintD(e1, e2, op_lt)

            my_lt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
            return [equality_constraint], counter
        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
764
+
765
+
766
@register_inference_rule(torch.full)
def full_inference_rule(n: Node, symbols, constraints, counter):
    """The output type of ``torch.full`` is fixed by its size argument:
    literal ints are used directly, nodes map to their dimension variables."""
    filled, counter = gen_tvar(counter)
    symbols[n] = filled

    assert isinstance(n.args[0], Iterable)
    size_dims = [d if isinstance(d, int) else symbols[d] for d in n.args[0]]
    eq = BinConstraintT(filled, TensorType(size_dims), op_eq)  # type: ignore[arg-type]
    return [eq], counter
778
+
779
+
780
# TODO normalize index
@register_inference_rule(torch.arange)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """Constrain ``torch.arange(end)`` to a rank-1 tensor whose length is
    int((end - start) / step); only the single-argument form is supported."""
    start = 0
    step = 1

    if len(n.args) == 1:
        end = symbols[n.args[0]]
    else:
        raise NotImplementedError('Not yet implemented')

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # either the a parameter is a number or it is Dyn
    # NOTE(review): start and step are the literals 0 and 1 here, so their
    # Dyn comparisons are decided trivially; only `end` can actually be Dyn.
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
811
+
812
def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
    """Emit broadcasting constraints relating ``e1``/``e2`` to ``output_var``.

    Two fresh tensor variables stand for the broadcasted forms of the
    operands; the output is their greatest upper bound and they must be
    consistent with each other. Returns (constraints, counter).
    """
    # Fresh variables for the post-broadcast shapes of the two operands.
    broadcast1, counter = gen_tvar(counter)
    broadcast2, counter = gen_tvar(counter)

    generated = [
        TGreatestUpperBound(output_var, broadcast1, broadcast2),
        ApplyBroadcasting(broadcast1, broadcast2, e1, e2),
        BinConstraintT(broadcast1, broadcast2, op_consistency),
    ]
    return generated, counter
822
+
823
+
824
@register_inference_rule(operator.mul)
@register_inference_rule(torch.ne)
@register_inference_rule("ne")
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
    """Shared rule for add/mul/ne: tensor/tensor operands get broadcasting
    constraints; tensor/scalar keeps the tensor's type; dim-variable/scalar
    propagates the arithmetic onto the dimension."""

    # Arithmetic opcode used only on the dimension-variable/scalar paths.
    # NOTE(review): op_code stays None for torch.ne / "ne"; a DVar/scalar
    # operand pair for those targets would build a constraint with a None
    # op — confirm those targets never reach that path.
    op_code = None
    if n.target == operator.add or n.target == torch.add:
        op_code = op_add
    elif n.target == operator.mul:
        op_code = op_mul

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]
            e2 = symbols[n.args[1]]

            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
        else:
            raise NotImplementedError('Method not yet implemented')

    elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
        if isinstance(symbols[n.args[0]], TVar):
            # Tensor op scalar: the output type equals the tensor's type.
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]
            return [BinConstraintT(my_output, e1, op_eq)], counter
        elif isinstance(symbols[n.args[0]], DVar):
            my_output, counter = gen_dvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]

            # we will propagate the runtime value here since this is regular addition
            c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
                      BinConstraintD(0, my_output, op_leq)])
            return [c], counter

    elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
        if isinstance(symbols[n.args[1]], TVar):
            # Scalar op tensor: symmetric to the case above.
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e2 = symbols[n.args[1]]
            return [BinConstraintT(my_output, e2, op_eq)], counter
        elif isinstance(symbols[n.args[1]], DVar):
            my_output, counter = gen_dvar(counter)
            symbols[n] = my_output
            e2 = symbols[n.args[1]]

            # we will propagate the runtime value here since this is regular addition
            c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
                      BinConstraintD(0, my_output, op_leq)])
            return [c], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')
886
+
887
+
888
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """Either input and output are both Dyn, or the input has some rank in
    1..MAX_TENSOR_RANK and the [start_dim, end_dim] range collapses into a
    single dimension."""
    assert isinstance(n.args[0], Node)

    # generate the new variable
    flattened, counter = gen_tvar(counter)
    symbols[n] = flattened

    input = symbols[n.args[0]]

    # set the default start and end dims
    start_dim = 1
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    c1 = BinConstraintT(input, Dyn, op_eq)
    c2 = BinConstraintT(flattened, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    # One flatten-constraint instance per candidate input rank.
    const = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
        const.append(c)

    return [Disj([both_dyn, *const])], counter
920
+
921
+
922
+ @register_inference_rule(torch.nn.functional.layer_norm)
923
+ def layer_norm_functional(n: Node, symbols, constraints, counter):
924
+ """
925
+ We generate the constraint: input = output
926
+ """
927
+ assert isinstance(n.args[0], Node)
928
+ return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
929
+
930
+
931
+ @register_inference_rule(torch.nn.LayerNorm)
932
+ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
933
+ """
934
+ Input and output shapes should be equal.
935
+ Input should be consistent with the normalized_shape
936
+ """
937
+ assert isinstance(n.args[0], Node)
938
+ return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
939
+
940
+
941
+ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
942
+ output, counter = gen_tvar(counter)
943
+ symbols[n] = output
944
+ input = symbols[n.args[0]]
945
+
946
+ input_dyn = BinConstraintT(input, Dyn, op_eq)
947
+ output_dyn = BinConstraintT(output, Dyn, op_eq)
948
+
949
+ c1 = Conj([input_dyn, output_dyn])
950
+
951
+ c2 = []
952
+ for i in range(1, MAX_TENSOR_RANK + 1):
953
+ new_dims_rhs, counter = gen_tensor_dims(i, counter)
954
+ nat_constraints = gen_nat_constraints(new_dims_rhs)
955
+
956
+ c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
957
+ BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
958
+ add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
959
+ nat_constraints)
960
+ c2.append(c_tensor_i)
961
+ return [Disj([c1, Disj(c2)])], counter
962
+
963
+ @register_inference_rule(torch.nn.Dropout)
964
+ @register_inference_rule(torch.nn.ReLU)
965
+ def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
966
+ """
967
+ Input and output shapes should be equal.
968
+ """
969
+ assert isinstance(n.args[0], Node)
970
+ output, counter = gen_tvar(counter)
971
+ symbols[n] = output
972
+ input = symbols[n.args[0]]
973
+ assert isinstance(input, TVar)
974
+ return [BinConstraintT(input, output, op_eq)], counter
975
+
976
+
977
+ @register_inference_rule(torch.nn.Linear)
978
+ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
979
+ """
980
+ Input and output sizes should be the same except for the last dimension
981
+ If the input is Dyn, then so should the output
982
+ """
983
+ assert isinstance(n.args[0], Node)
984
+ return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)
985
+
986
+
987
+ @register_inference_rule("dim") # type: ignore[attr-defined]
988
+ def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
989
+ assert isinstance(n.args[0], Node)
990
+ my_dim, counter = gen_dvar(counter)
991
+ symbols[n] = my_dim
992
+ input = symbols[n.args[0]]
993
+
994
+ input_dyn = BinConstraintT(input, Dyn, op_eq)
995
+ output_dyn = BinConstraintD(my_dim, Dyn, op_eq)
996
+
997
+ c1 = []
998
+
999
+ for i in range(1, MAX_TENSOR_RANK + 1):
1000
+ new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
1001
+
1002
+ c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
1003
+ BinConstraintD(my_dim, i, op_eq)])
1004
+ c1.append(c_tensor_i)
1005
+
1006
+ return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
1007
+
1008
+
1009
+ @register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined]
1010
+ def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
1011
+ assert isinstance(n.args[0], Node)
1012
+ weight_dims, counter = gen_tensor_dims(2, counter)
1013
+ equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
1014
+ constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
1015
+ return [equality_constraint] + constraints, counter
1016
+
1017
+
1018
+ def linear_constraints(n: Node, in_features, out_features, symbols, counter):
1019
+ linear_output, counter = gen_tvar(counter)
1020
+ symbols[n] = linear_output
1021
+ linear_input = symbols[n.args[0]]
1022
+
1023
+ input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
1024
+ output_dyn = BinConstraintT(linear_output, Dyn, op_eq)
1025
+
1026
+ c1 = Conj([input_dyn, output_dyn])
1027
+
1028
+ c2 = []
1029
+ for i in range(1, MAX_TENSOR_RANK + 1):
1030
+ new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
1031
+ new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
1032
+
1033
+ nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
1034
+
1035
+ c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
1036
+ BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
1037
+ add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
1038
+ nat_constraints)
1039
+ c2.append(c_tensor_i)
1040
+ return [Disj([c1, Disj(c2)])], counter
1041
+
1042
+ def add_layer_norm_constraints(input_dim, normalized_dim):
1043
+ """
1044
+ The constraints say that the type has te form: [*, 1024, 1024]
1045
+ while the normalized_dim have the form [1024, 1024]
1046
+ Args:
1047
+ input_dim: Input shape of layer norm
1048
+ normalized_dim: normalized_dim parameter of the module instance
1049
+
1050
+ """
1051
+
1052
+ # in this case we return false since there's a pattern mismatch
1053
+ if len(normalized_dim) > len(input_dim):
1054
+ return [F()]
1055
+
1056
+ else:
1057
+ constraints = []
1058
+ for i, n in zip(reversed(input_dim), reversed(normalized_dim)):
1059
+ constraints.append(BinConstraintD(i, n, op_consistency))
1060
+ return constraints
1061
+
1062
+
1063
+ def add_linear_constraints(dims1, dims2, in_features, out_features):
1064
+ assert len(dims1) == len(dims2)
1065
+ constraints = []
1066
+ for i in range(len(dims1)):
1067
+ if i == len(dims1) - 1:
1068
+ constraints.append(BinConstraintD(dims1[i], in_features, op_consistency))
1069
+ constraints.append(BinConstraintD(dims2[i], out_features, op_eq))
1070
+ else:
1071
+ constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq))
1072
+
1073
+ return constraints
1074
+
1075
+
1076
+ @register_inference_rule(torch.reshape)
1077
+ def reshape_inference_rule(n: Node, symbols, constraints, counter):
1078
+ assert isinstance(n.args[0], Node)
1079
+
1080
+ # generate the new variable
1081
+ my_reshape, counter = gen_tvar(counter)
1082
+ symbols[n] = my_reshape
1083
+
1084
+ src_var = symbols[n.args[0]]
1085
+ t2 = n.args[1]
1086
+ t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr]
1087
+ c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr]
1088
+ c2 = CanReshape(src_var, t2_type)
1089
+
1090
+ return [c1, c2], counter
1091
+
1092
+
1093
+ @register_inference_rule(BatchNorm2d)
1094
+ def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
1095
+ assert isinstance(n.args[0], Node)
1096
+
1097
+ # generate the new variable
1098
+ batchnorm_output, counter = gen_tvar(counter)
1099
+ symbols[n] = batchnorm_output
1100
+ batchnorm_input = symbols[n.args[0]]
1101
+
1102
+ # dim vars
1103
+ d1, counter = gen_dvar(counter)
1104
+ d2, counter = gen_dvar(counter)
1105
+ d3, counter = gen_dvar(counter)
1106
+ d4, counter = gen_dvar(counter)
1107
+
1108
+ nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
1109
+
1110
+ c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching)
1111
+ c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq)
1112
+ return [c1, c2, *nat_constraints], counter
1113
+
1114
+
1115
+ @register_inference_rule(torch.nn.AdaptiveAvgPool2d)
1116
+ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
1117
+ assert isinstance(n.args[0], Node)
1118
+
1119
+ avg_pool, counter = gen_tvar(counter)
1120
+
1121
+ symbols[n] = avg_pool
1122
+ input_var = symbols[n.args[0]]
1123
+
1124
+ # dim vars
1125
+ d1, counter = gen_dvar(counter)
1126
+ d2, counter = gen_dvar(counter)
1127
+ d3, counter = gen_dvar(counter)
1128
+ d4, counter = gen_dvar(counter)
1129
+ nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
1130
+ c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
1131
+ c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)
1132
+
1133
+ return [c1, c2, *nat_constraints], counter
1134
+
1135
+
1136
+ @register_inference_rule(Conv2d)
1137
+ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
1138
+ assert isinstance(n.args[0], Node)
1139
+
1140
+ my_conv, counter = gen_tvar(counter)
1141
+ symbols[n] = my_conv
1142
+ input_var = symbols[n.args[0]]
1143
+
1144
+ # dim vars
1145
+ [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
1146
+
1147
+ # c1 = Matching(input_var, TensorType([d1, d2, d3, d4]))
1148
+ c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
1149
+
1150
+ # c2 = DConsistency(module_instance.in_channels, d2)
1151
+ c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
1152
+
1153
+ c3 = CalcConv(my_conv, input_var,
1154
+ module_instance.out_channels,
1155
+ module_instance.kernel_size,
1156
+ module_instance.padding,
1157
+ module_instance.stride,
1158
+ module_instance.dilation, [d1, d2, d3, d4])
1159
+
1160
+ nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
1161
+
1162
+ return [c1, c2, c3, *nat_constraints], counter
1163
+
1164
+
1165
+ @register_inference_rule(torch.nn.MaxPool2d)
1166
+ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
1167
+ assert isinstance(n.args[0], Node)
1168
+ maxpool, counter = gen_tvar(counter)
1169
+ symbols[n] = maxpool
1170
+ input_var = symbols[n.args[0]]
1171
+
1172
+ # dim vars
1173
+ [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
1174
+
1175
+ c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
1176
+
1177
+ c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
1178
+ module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])
1179
+
1180
+ nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
1181
+
1182
+ return [c1, c2, *nat_constraints], counter
1183
+
1184
+
1185
+ class ConstraintGenerator:
1186
+ def __init__(self, traced, graph=None):
1187
+ self.traced = traced # traced or tracer.root
1188
+ self.traced_params = dict(self.traced.named_parameters())
1189
+ self.constraints = []
1190
+ self.symbol_dict = {}
1191
+ self.graph = traced.graph if hasattr(traced, 'graph') else graph
1192
+
1193
+
1194
+ def generate_constraints(self, counter=0):
1195
+ """
1196
+ Iterate through every node and generate constraints
1197
+ Effect: self.constraints will be populated with the final constraints
1198
+ """
1199
+ graph = self.graph
1200
+
1201
+ all_constraints = []
1202
+
1203
+ for n in graph.nodes:
1204
+ (constraints, counter) = self.generate_constraints_node(n, counter)
1205
+ all_constraints += constraints
1206
+
1207
+ return Conj(all_constraints), counter
1208
+
1209
+ def generate_constraints_node(self, n: Node, counter):
1210
+ """
1211
+ Generate constraints the given node:
1212
+ Currently supported operations:
1213
+ - Reshape
1214
+ - Add
1215
+ - conv2d
1216
+ """
1217
+
1218
+ if n.op == 'placeholder':
1219
+ x, counter = gen_tvar(counter)
1220
+ self.symbol_dict[n] = x
1221
+
1222
+ my_type = n.type
1223
+
1224
+ if n.type != Dyn and (not isinstance(n.type, TensorType)):
1225
+ if n.type == torch.nn.parameter.Parameter:
1226
+ # since we have a parameter, the shape must be static
1227
+ assert 'example_value' in n.meta
1228
+ my_type = TensorType(n.meta['example_value'].size())
1229
+ else:
1230
+ my_type = Dyn
1231
+
1232
+ c1 = BinConstraintT(my_type, x, op_precision)
1233
+ c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
1234
+ return [c1, c2], counter
1235
+
1236
+ elif n.op == 'call_function':
1237
+ if n.target in _INFERENCE_RULES:
1238
+ return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
1239
+ else:
1240
+ raise RuntimeError(f'No inference rule registered for target {n.target}!')
1241
+
1242
+ elif n.op == 'call_module':
1243
+
1244
+ module_instance = self.traced.get_submodule(n.target)
1245
+ if type(module_instance) in _INFERENCE_RULES:
1246
+ return _INFERENCE_RULES[type(module_instance)](n,
1247
+ module_instance,
1248
+ self.symbol_dict,
1249
+ self.constraints, counter)
1250
+ else:
1251
+ raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
1252
+
1253
+ elif n.op == 'call_method':
1254
+ if n.target in _INFERENCE_RULES:
1255
+ return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
1256
+ else:
1257
+ raise RuntimeError(f'No inference rule registered for target {n.target}!')
1258
+
1259
+ elif n.op == 'get_attr':
1260
+ t = self.traced_params.get(n.target, None)
1261
+
1262
+ if isinstance(t, torch.Tensor):
1263
+ if len(t.shape) > 0:
1264
+ res = list(t.shape)
1265
+ attr_type = TensorType(res)
1266
+ output, counter = gen_tvar(counter)
1267
+ self.symbol_dict[n] = output
1268
+ return [BinConstraintT(output, attr_type, op_eq)], counter
1269
+ else:
1270
+ # scalar?
1271
+ return [], counter
1272
+ else:
1273
+ return [], counter
1274
+
1275
+ elif n.op == 'output':
1276
+ return [], counter
1277
+
1278
+ else:
1279
+ raise NotImplementedError(f"Method {n.op} not yet implemented")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py ADDED
@@ -0,0 +1,1040 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+ import copy
3
+ import itertools
4
+ from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
5
+ from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
6
+ Transpose
7
+ from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
8
+ from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
9
+ from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
10
+ from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
11
+ from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
12
+ from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
13
+ from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
14
+ from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
15
+ from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
16
+ from torch.fx.tensor_type import TensorType, Dyn
17
+ from typing import Callable, Dict, List
18
+
19
+ _TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
20
+
21
+
22
+ def register_transformation_rule(call_target):
23
+ def register(fn):
24
+ if call_target in _TRANSFORMATION_RULES:
25
+ raise RuntimeError(f'Transformation rule already registered for {call_target}!')
26
+ _TRANSFORMATION_RULES[call_target] = fn
27
+ return fn
28
+ return register
29
+
30
+
31
+ def valid_index(index, dims):
32
+ """
33
+ Given a list of dimensions, checks if an index is valid in the list
34
+ """
35
+ try:
36
+ dims[index]
37
+ return T()
38
+ except IndexError:
39
+ return F()
40
+
41
+
42
+ @register_transformation_rule(Transpose)
43
+ def transform_transpose(constraint, counter):
44
+ """
45
+ Similar to a sequence of two index-selects
46
+ """
47
+ dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
48
+ is_valid_index1 = valid_index(constraint.index1, dims)
49
+ is_valid_index2 = valid_index(constraint.index2, dims)
50
+ new_dims = copy.deepcopy(dims)
51
+ nat_constraints = gen_nat_constraints(dims)
52
+
53
+ if is_valid_index1 == T() and is_valid_index2 == T():
54
+ new_dims[constraint.index1] = dims[constraint.index2]
55
+ new_dims[constraint.index2] = dims[constraint.index1]
56
+
57
+ transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
58
+ *nat_constraints,
59
+ is_valid_index1, is_valid_index2,
60
+ BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
61
+ return transformed_constraint, counter
62
+
63
+
64
+ @register_transformation_rule(IndexSelect)
65
+ def transform_index_select(constraint, counter):
66
+ """
67
+ The constraints consider the given tensor size, checks if the index is valid
68
+ and if so, generates a constraint for replacing the input dimension
69
+ with the required dimension
70
+ """
71
+ dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
72
+ is_valid_index = valid_index(constraint.index, dims)
73
+ nat_constraints = gen_nat_constraints(dims)
74
+
75
+ # if the index is valid then replace the input dimension with the new dimension
76
+ # otherwise the dimension will not be replaced and the clause will contain False
77
+ if is_valid_index == T():
78
+ new_dims = copy.deepcopy(dims)
79
+ new_dims[constraint.index] = constraint.dim_replace
80
+
81
+ transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
82
+ *nat_constraints,
83
+ is_valid_index,
84
+ BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
85
+
86
+ # print(constraints)
87
+ return transformed_constraint, counter
88
+
89
+
90
+ @register_transformation_rule(GetItem)
91
+ def transform_get_item(constraint, counter):
92
+ """
93
+ generate an equality of the form:
94
+ t = [a1, ..., an]
95
+ then generate constraints that check if the given index is valid
96
+ given this particular tensor size.
97
+ If the index is valid, generate a constraint to get the item
98
+ Note that we already handled the Dyn input case in the previous
99
+ step.
100
+ Args:
101
+ constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
102
+ counter: variable tracking
103
+ Returns: simplified constraints for GetItem
104
+
105
+ """
106
+ dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
107
+ nat_constraints = gen_nat_constraints(dims)
108
+
109
+
110
+ is_valid_index = valid_index(constraint.index, dims)
111
+
112
+ all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
113
+ *nat_constraints,
114
+ is_valid_index]
115
+
116
+ # if the index is valid, we generate a constraint for getting an item
117
+ # otherwise this clause will have been UNSAT due to the wrong index
118
+ if is_valid_index == T():
119
+ all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))
120
+
121
+ return Conj(all_constraints), counter
122
+
123
+ def valid_index_tensor(index, dims):
124
+ """
125
+ if the slice instances exceed the length of the dimensions
126
+ then this is a type error so we return False
127
+ """
128
+ slice_count = 0
129
+ for s in index:
130
+ if isinstance(s, slice):
131
+ slice_count += 1
132
+ if slice_count > len(dims):
133
+ return F()
134
+ else:
135
+ return T()
136
+
137
+ @register_transformation_rule(GetItemTensor)
138
+ def transform_get_item_tensor(constraint, counter):
139
+ """
140
+ When the index is a tuple, then the output will be a tensor
141
+ TODO: we have to check if this is the case for all HF models
142
+
143
+ The cases we are covering here are a tuple with one of:
144
+ - slice with default argument
145
+ - None
146
+
147
+ None appends 1 to the input tensor dimensions
148
+ so each occurrence of 'None' increases the rank by 1
149
+
150
+ slice with default arguments does not change the rank
151
+ """
152
+ assert isinstance(constraint.index_tuple, tuple)
153
+
154
+
155
+ # generate a result tensor of the expected size
156
+ dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
157
+ nat_constraints = gen_nat_constraints(dims)
158
+
159
+ # generate a place-holder list of the right rank
160
+ # where "slice" does not contribute to the rank and "None" does
161
+ none_c = constraint.index_tuple.count(None)
162
+ resulting_tensor_dims = (none_c + len(dims)) * [None]
163
+
164
+ dim_index = 0
165
+ for i in range(len(constraint.index_tuple)):
166
+
167
+ # append 1 to the right location of the resulting tensor
168
+ if constraint.index_tuple[i] is None:
169
+ resulting_tensor_dims[i] = 1
170
+
171
+ elif constraint.index_tuple[i] == slice(None, None, None):
172
+ pass
173
+
174
+ else:
175
+ raise NotImplementedError('Method not yet implemented')
176
+
177
+ # append the remaining dimensions to the right location
178
+ dim_index = 0
179
+ for i in range(len(resulting_tensor_dims)):
180
+ if resulting_tensor_dims[i] is None:
181
+ resulting_tensor_dims[i] = dims[dim_index]
182
+ dim_index += 1
183
+
184
+ # check if the index is valid
185
+ is_valid_index = valid_index_tensor(constraint.index_tuple, dims)
186
+
187
+ # check if the resulting tensor is within bounds
188
+ if len(resulting_tensor_dims) > 4:
189
+ return F(), counter
190
+
191
+ else:
192
+ constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
193
+ BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
194
+ *nat_constraints,
195
+ is_valid_index]
196
+ return Conj(constraints), counter
197
+
198
+
199
+ @register_transformation_rule(BinConstraintT)
200
+ def generate_binconstraint_t(constraint, counter):
201
+ """
202
+ Transform binary constraints for tensors
203
+ """
204
+
205
+ # precision constraints
206
+ if constraint.op == op_precision:
207
+ if constraint.lhs == Dyn:
208
+ return T(), counter
209
+ elif isinstance(constraint.lhs, TensorType):
210
+ is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
211
+ if is_fully_static:
212
+ return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
213
+ else:
214
+ new_dims = []
215
+
216
+ for _ in range(len(constraint.lhs.__args__)):
217
+ dim, counter = gen_dvar(counter)
218
+ new_dims.append(dim)
219
+
220
+ new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
221
+ new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
222
+ [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
223
+ [BinConstraintD(1, new_dim, op_leq) for
224
+ new_dim in new_dims]
225
+ return Conj(new_dim_constraints), counter
226
+
227
+ # matching
228
+ elif constraint.op == op_matching:
229
+ assert isinstance(constraint.rhs, TensorType)
230
+ d1 = constraint.rhs.__args__[0]
231
+ d2 = constraint.rhs.__args__[1]
232
+ d3 = constraint.rhs.__args__[2]
233
+ d4 = constraint.rhs.__args__[3]
234
+
235
+ conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
236
+ BinConstraintD(d1, Dyn, op_eq),
237
+ BinConstraintD(d2, Dyn, op_eq),
238
+ BinConstraintD(d3, Dyn, op_eq),
239
+ BinConstraintD(d4, Dyn, op_eq)]
240
+ return Disj([Conj(conj),
241
+ BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter
242
+
243
+ elif constraint.op == op_consistency:
244
+ c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
245
+ [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)
246
+
247
+ return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter
248
+
249
+ elif constraint.op == op_leq:
250
+ assert isinstance(constraint.rhs, int)
251
+ disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
252
+ for i in range(1, constraint.rhs + 1):
253
+ dims = []
254
+ for j in range(1, i + 1):
255
+ dim_var, counter = gen_dvar(counter)
256
+ dims.append(dim_var)
257
+ disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
258
+ return Disj(disj), counter
259
+ else:
260
+ return constraint, counter
261
+
262
+
263
+ @register_transformation_rule(BinConstraintD)
264
+ def generate_binconstraint_d(constraint, counter):
265
+ """
266
+ Transform binary constraints for dimensions
267
+ """
268
+ if constraint.op == op_precision:
269
+ if isinstance(constraint.lhs, int):
270
+ return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
271
+ elif constraint.lhs == Dyn:
272
+ return T(), counter
273
+
274
+ elif constraint.op == op_consistency:
275
+ return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
276
+ BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter
277
+
278
+ else:
279
+ return constraint, counter
280
+
281
+
282
+ @register_transformation_rule(Conj)
283
+ def generate_conj(constraint, counter):
284
+ """
285
+ Transform conjunctions
286
+ """
287
+ new = []
288
+ for c in constraint.conjucts:
289
+ new_c, counter = transform_constraint(c, counter)
290
+ new.append(new_c)
291
+ return Conj(new), counter
292
+
293
+
294
+ @register_transformation_rule(Disj)
295
+ def generate_disj(constraint, counter):
296
+ """
297
+ Transform disjunctions
298
+ """
299
+ new = []
300
+ for c in constraint.disjuncts:
301
+ new_c, counter = transform_constraint(c, counter)
302
+ new.append(new_c)
303
+ return Disj(new), counter
304
+
305
+
306
+ @register_transformation_rule(TGreatestUpperBound)
307
+ def generate_gub(constraint, counter):
308
+ """
309
+ Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
310
+ on dimensions
311
+ """
312
+ c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
313
+ BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)])
314
+
315
+ [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)
316
+
317
+ return Disj([c1, c2, c3, c4, c5]), counter
318
+
319
+
320
+ @register_transformation_rule(DGreatestUpperBound)
321
+ def generate_d_gub(constraint, counter):
322
+ """
323
+ Transform greatest upper bound for dimensions into equality constraints
324
+ """
325
+ c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)])
326
+ c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
327
+ c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
328
+ return Disj([c1, c2, c3]), counter
329
+
330
+
331
+ @register_transformation_rule(CalcConv)
332
+ def generate_calc_conv(constraint, counter):
333
+ d, counter = gen_tensor_dims(4, counter)
334
+ conv_result = TensorType([d[0], d[1], d[2], d[3]])
335
+
336
+ # the convolution result is a tensor of size 4
337
+ c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)
338
+
339
+ # the second dimension of the output is equal to the output channels
340
+ c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])
341
+
342
+ # the input corresponds to the output in the first dimension of the convolution
343
+ c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
344
+
345
+ c4, c5 = calc_last_two_dims(constraint, d)
346
+
347
+ leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
348
+ BinConstraintD(0, d[1], op_leq),
349
+ BinConstraintD(0, d[2], op_leq),
350
+ BinConstraintD(0, d[3], op_leq)])
351
+
352
+ return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
353
+
354
+
355
+ @register_transformation_rule(CalcMaxPool)
356
+ def generate_calc_maxpool(constraint, counter):
357
+ """
358
+ Transform maxpool constraints
359
+ """
360
+ d, counter = gen_tensor_dims(4, counter)
361
+ maxpool_result = TensorType([d[0], d[1], d[2], d[3]])
362
+
363
+ # the maxpool result is a tensor of size 4
364
+ c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)
365
+
366
+ # the input corresponds to the output in the first and second dimension of maxpool
367
+ c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
368
+ c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
369
+ c4, c5 = calc_last_two_dims(constraint, d)
370
+
371
+ leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
372
+ BinConstraintD(0, d[1], op_leq),
373
+ BinConstraintD(0, d[2], op_leq),
374
+ BinConstraintD(0, d[3], op_leq)])
375
+
376
+ return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
377
+
378
+
379
+ @register_transformation_rule(CalcProduct)
380
+ def generate_calc_product(constraint, counter):
381
+ """
382
+ Transform flatten constraints
383
+ """
384
+ start = constraint.start
385
+ end = constraint.end
386
+ dims = constraint.dims_to_flatten
387
+ flattened = constraint.flattened
388
+ n = len(constraint.dims_to_flatten)
389
+
390
+ # this will be evaluated right here
391
+ boundary_check = (0 <= start and start < end and end <= n)
392
+
393
+ c_boundary = T() if boundary_check else F()
394
+
395
+ lhs = dims[0:start]
396
+ rhs = dims[end:]
397
+ mid = dims[start:end]
398
+
399
+ all_possibilities = generate_all_int_dyn_dim_possibilities(mid)
400
+
401
+ all_constraints = []
402
+
403
+ for p in all_possibilities:
404
+ p = list(p)
405
+ # this tells us there is a dynamic variable
406
+ contains_dyn = not all(constraint.op == op_neq for constraint in p)
407
+ if contains_dyn:
408
+ mid_var = [Dyn]
409
+ total_constraints = lhs + mid_var + rhs
410
+ if len(total_constraints) > 4:
411
+ all_constraints.append(F())
412
+ else:
413
+ all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p))
414
+ else:
415
+ new_var, counter = gen_dvar(counter)
416
+ mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)])
417
+ mid_var = [new_var]
418
+ total_constraints = lhs + mid_var + rhs
419
+ if len(total_constraints) > 4:
420
+ all_constraints.append(F())
421
+ else:
422
+ all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p))
423
+
424
+ return Conj([Disj(all_constraints), c_boundary]), counter
425
+
426
+
427
+ @register_transformation_rule(CanReshape)
428
+ def generate_reshape(constraint, counter):
429
+ """
430
+ Transform reshape constraints
431
+ """
432
+ d, counter = gen_tensor_dims(4, counter)
433
+
434
+ d1 = d[0]
435
+ d2 = d[1]
436
+ d3 = d[2]
437
+ d4 = d[3]
438
+
439
+ target = constraint.target.__args__
440
+
441
+ is_fully_static = all(d != Dyn for d in target)
442
+
443
+ # dynamic tensor
444
+ c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
445
+ c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
446
+ c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
447
+ c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
448
+ c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)
449
+
450
+ d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
451
+ d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)
452
+
453
+ d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
454
+ d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)
455
+
456
+ d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
457
+ d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
458
+
459
+ d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
460
+ d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
461
+
462
+ nat_d1 = BinConstraintD(0, d1, op_leq)
463
+ nat_d2 = BinConstraintD(0, d2, op_leq)
464
+ nat_d3 = BinConstraintD(0, d3, op_leq)
465
+ nat_d4 = BinConstraintD(0, d4, op_leq)
466
+
467
+ if is_fully_static:
468
+ # size 1 tensor
469
+ c3_tensor1 = Disj([d1_eq_dyn,
470
+ (Conj([d1_neq_dyn,
471
+ BinConstraintD(d1, Prod(target), op_eq)]))])
472
+ all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
473
+
474
+ # size 2 tensor
475
+ all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])
476
+
477
+ # size 3 tensor
478
+ all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])
479
+
480
+ # size 4 tensor
481
+ all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])
482
+
483
+ return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
484
+ nat_d1, nat_d2, nat_d3, nat_d4]), counter
485
+
486
+ # then there must be exactly one occurrence of dyn
487
+ else:
488
+ new_target = []
489
+
490
+ for n in target:
491
+ if n != Dyn:
492
+ new_target.append(n)
493
+
494
+ # tensor 1
495
+ c3_tensor1 = Disj([d1_eq_dyn,
496
+ (Conj([d1_neq_dyn,
497
+ is_dim_div_by_target(new_target, d1)]))])
498
+ all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
499
+
500
+ # tensor 2
501
+ c21 = Disj([d1_eq_dyn, d2_eq_dyn])
502
+ c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
503
+ all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])
504
+
505
+ # tensor 3
506
+ c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
507
+ c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
508
+ all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])
509
+
510
+ # tensor 4
511
+ c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
512
+ c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
513
+ all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])
514
+
515
+ return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
516
+ nat_d1, nat_d2, nat_d3, nat_d4]), counter
517
+
518
+
519
+ @register_transformation_rule(ApplyBroadcasting)
520
+ def generate_broadcasting(constraint, counter):
521
+ """
522
+ Transform broadcasting constraints
523
+ """
524
+ e11, e12 = constraint.res1, constraint.res2
525
+ e1, e2 = constraint.input1, constraint.input2
526
+
527
+ e1_dyn = BinConstraintT(e1, Dyn, op_eq)
528
+ e2_dyn = BinConstraintT(e2, Dyn, op_eq)
529
+
530
+ # Introduce dimensions
531
+ e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
532
+ e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
533
+
534
+ # dyn possibility
535
+ e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
536
+ e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])
537
+
538
+ # tensor possibility
539
+ # generate dimensions to create tensors of size 1
540
+ final_tensor_1_constraint, _, _, nat_dims_1, counter = \
541
+ gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)
542
+
543
+ # generate dimensions to create tensors of size 2
544
+ final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
545
+ final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
546
+ gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)
547
+
548
+ # generate dimensions to create tensors of size 3
549
+ final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
550
+ final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
551
+ gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)
552
+
553
+ # generate dimensions to create tensors of size 4
554
+ final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
555
+ final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
556
+ gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)
557
+
558
+ final_result = Disj([
559
+ e1_dyn_constraint,
560
+ e2_dyn_constraint,
561
+ final_tensor_1_constraint,
562
+ final_tensor_2_constraint_no_padding,
563
+ final_tensor_2_constraint_padding_arg1,
564
+ final_tensor_2_constraint_padding_arg2,
565
+ final_tensor_3_constraint_no_padding,
566
+ final_tensor_3_constraint_padding_arg1,
567
+ final_tensor_3_constraint_padding_arg2,
568
+ final_tensor_4_constraint_no_padding,
569
+ final_tensor_4_constraint_padding_arg1,
570
+ final_tensor_4_constraint_padding_arg2
571
+ ])
572
+
573
+ return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter
574
+
575
+
576
+ def transform_constraint(constraint: Constraint, counter: int):
577
+ """
578
+ Transforms a constraint into a simpler constraint.
579
+ Ex: precision and consistency are transformed to equality
580
+ Args:
581
+ constraint: constraint to be transformed
582
+ counter: for variable tracking
583
+
584
+ Returns: Constraint
585
+
586
+ """
587
+ if type(constraint) in _TRANSFORMATION_RULES:
588
+ return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)
589
+
590
+ else:
591
+ return constraint, counter
592
+
593
+
594
+
595
+
596
+ def calc_last_two_dims(constraint, d: List[DVar]):
597
+ """
598
+ Generates constraints for the last two dimensions of a convolution or a maxpool output
599
+ Args:
600
+ constraint: CalcConv or CalcMaxPool
601
+ d: The list of output dimensions
602
+
603
+ Returns: Constraints for calculating the last two dimensions of the output
604
+
605
+ """
606
+
607
+ assert isinstance(constraint, (CalcConv, CalcMaxPool))
608
+
609
+ b3 = constraint.matching_constraint[2]
610
+ b4 = constraint.matching_constraint[3]
611
+
612
+ b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
613
+ b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])
614
+
615
+ d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)])
616
+ d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)])
617
+
618
+ # transform parameters into tuples incase they are not already
619
+ padding = (constraint.padding, constraint.padding) \
620
+ if isinstance(constraint.padding, int) else constraint.padding
621
+ kernel = (constraint.kernel, constraint.kernel) \
622
+ if isinstance(constraint.kernel, int) else constraint.kernel
623
+ stride = (constraint.stride, constraint.stride) \
624
+ if isinstance(constraint.stride, int) else constraint.stride
625
+ dilation = (constraint.dilation, constraint.dilation) \
626
+ if isinstance(constraint.dilation, int) else constraint.dilation
627
+
628
+ f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
629
+ f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
630
+ f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div)
631
+ f4 = BinConstraintD(f3, 1, op_add)
632
+
633
+ c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])
634
+
635
+ f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
636
+ f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
637
+ f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div)
638
+ f44 = BinConstraintD(f33, 1, op_add)
639
+
640
+ c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])
641
+
642
+ return c4, c5
643
+
644
+
645
+ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
646
+ """
647
+ Generate all possibilities of being equal or not equal to dyn for my_list
648
+ Args:
649
+ my_list: List of tensor dimensions
650
+
651
+ Returns: A list of a list of constraints. Each list of constraints corresponds to
652
+ one possibility about the values of the dimension variables
653
+ """
654
+ # generate all possibilities of being equal or not equal to dyn for my_list
655
+ eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))]
656
+ neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))]
657
+ d_possibilities = []
658
+
659
+ for i in zip(eq_possibilities, neq_possibilities):
660
+ d_possibilities.append(list(i))
661
+ all_possibilities = list(itertools.product(*d_possibilities))
662
+ return all_possibilities
663
+
664
+
665
+ def is_target_div_by_dim(target: List[int], dim: List[DVar]):
666
+ """
667
+ Generate constraints to check if the target dimensions are divisible by the input dimensions
668
+ Args:
669
+ target: Target dimensions
670
+ dim: Input dimensions
671
+
672
+ Returns: Constraints to check divisibility
673
+
674
+ """
675
+ return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
676
+
677
+
678
+ def is_dim_div_by_target(target: List[int], dim: List[DVar]):
679
+ """
680
+ Generate constraints to check if the input dimensions is divisible by the target dimensions
681
+ Args:
682
+ target: Target dimensions
683
+ dim: Input dimensions
684
+
685
+ Returns: Constraints to check divisibility
686
+
687
+ """
688
+ return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)
689
+
690
+
691
+ def gen_all_reshape_possibilities(list_of_dims, target):
692
+ """
693
+ Consider all possibilities what the input dimensions could be (number or dynamic)
694
+ Then generate the appropriate constraints using multiplication or mod depending on the possibility
695
+ The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
696
+ for the input. Target is fixed because at most one dimension could be dyn.
697
+ We have different cases for this.
698
+
699
+ Args:
700
+ list_of_dims: The input list of dimensions
701
+ target: The tensor we want to reshape to
702
+
703
+ Returns: A disjunction of transformed reshape constraints
704
+
705
+ """
706
+ all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)
707
+
708
+ all_constraints = []
709
+
710
+ for p in all_possibilities:
711
+ to_multiply = []
712
+
713
+ p = list(p)
714
+
715
+ for constraint in p:
716
+ assert isinstance(constraint, BinConstraintD)
717
+ if constraint.op == op_neq:
718
+ to_multiply.append(constraint.lhs)
719
+
720
+ if not to_multiply:
721
+ all_constraints.append(Conj(p))
722
+
723
+ elif len(to_multiply) < len(list_of_dims):
724
+ all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]))
725
+ else:
726
+ all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims),
727
+ Prod(target), op_eq)]))
728
+
729
+ return Disj(all_constraints)
730
+
731
+
732
+ def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
733
+ """
734
+ Apply broadcasting to the 'index' dimension of tensor_input1.
735
+ Args:
736
+ tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
737
+ tensor_input2: represents the second input
738
+ res1: broadcasted result 1
739
+ res2: broadcasted result 2
740
+ index: the index to broadcast
741
+ padding: If padding was used, then tensor_input1[index] does not exist
742
+
743
+ Returns:
744
+
745
+ """
746
+ if tensor_input1[index] is None:
747
+ assert padding
748
+
749
+
750
+ if not padding:
751
+ # then the inputs are the same length so they all have dimensions at "index"
752
+ return Conj([BinConstraintD(tensor_input1[index], 1, op_eq),
753
+ BinConstraintD(res1[index], res2[index], op_eq),
754
+ BinConstraintD(res2[index], tensor_input2[index], op_eq)])
755
+
756
+ else:
757
+ # we don't set the input dimension to 1, since it doesn't exist.
758
+ return Conj([BinConstraintD(res1[index], res2[index], op_eq),
759
+ BinConstraintD(res2[index], tensor_input2[index], op_eq)])
760
+
761
+
762
+ def apply_padding(e1_var: TVar,
763
+ e11: BinConstraintT,
764
+ e2: BinConstraintT,
765
+ e12: BinConstraintT,
766
+ d2: List[DVar],
767
+ d11: List[DVar],
768
+ d12: List[DVar],
769
+ counter: int):
770
+ """
771
+ We are considering the possibility where one input has less dimensions than
772
+ another input, so we apply padding to the broadcasted results
773
+
774
+ Args:
775
+ e1_var: Variable representing the first input where padding will be
776
+ e11: constraint of the form e11 = Tensortype[d1, ..., dn]
777
+ e2: constraint of the form e2 = Tensortype[d1, ..., dn]
778
+ e12: constraint of the form e11 = Tensortype[d1, ..., dn]
779
+ d2: Tensor variables for the second input
780
+ d11: Tensor variables for the broadcasted first input
781
+ d12: Tensor variables for the broadcasted second input
782
+ counter: variable tracking
783
+
784
+ Returns: A new constraint whose goal is to apply padding to the broadcasted result
785
+
786
+ """
787
+
788
+ res = []
789
+
790
+ # pad the shorter input with None so we can pass it to the broadcasting helper function
791
+ for i in range(1, len(d2)):
792
+
793
+ d1, counter = gen_tensor_dims(i, counter)
794
+
795
+ nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
796
+
797
+ e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)
798
+
799
+ simulate_padding = [None] * (len(d2) - i)
800
+
801
+ assert len(simulate_padding + d1) == len(d2)
802
+
803
+ broadcast_padding = []
804
+
805
+ # for every padding size, we also consider broadcasting
806
+ for j in range(len(d2) - i):
807
+ broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))
808
+
809
+ # we consider the possibilities for broadcasting for every dimension. Since we already
810
+ # padded d1, we do not consider it while broadcasting
811
+ all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
812
+ d2[(len(d2) - i):],
813
+ d11[(len(d2) - i):],
814
+ d12[(len(d2) - i):])
815
+ # combine all constraints into a conjunction
816
+ c = Conj([e1, e11, e2, e12,
817
+ *broadcast_padding,
818
+ all_broadcasting_possibilities,
819
+ *nat_constraints
820
+ ])
821
+ res.append(c)
822
+
823
+ return Disj(res), counter
824
+
825
+
826
+ def no_broadcast_dim_with_index(d1: List[DVar],
827
+ d2: List[DVar],
828
+ d3: List[DVar],
829
+ d4: List[DVar],
830
+ i: int):
831
+ """
832
+ Args:
833
+ d1: input 1
834
+ d2: input 2
835
+ d3: simulated broadcasting for input 1
836
+ d4: simulated broadcasting for input 2
837
+ i: the rank of the resulting tensor addition
838
+
839
+ Returns: Constraints for when no broadcasting occurs
840
+ """
841
+ return Conj([
842
+ Disj([
843
+ Conj([BinConstraintD(d1[i], 1, op_eq),
844
+ BinConstraintD(d2[i], 1, op_eq)]),
845
+
846
+ Conj([BinConstraintD(d1[i], 1, op_neq),
847
+ BinConstraintD(d2[i], 1, op_neq)])]),
848
+
849
+ BinConstraintD(d1[i], d3[i], op_eq),
850
+ BinConstraintD(d2[i], d4[i], op_eq)])
851
+
852
+
853
+
854
+ def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
855
+ """
856
+ Generate lists of DVar to represent tensor dimensions
857
+ Args:
858
+ num_tensors: the required number of tensors
859
+ dim_size: the number of dimensions for each tensor
860
+ counter: variable tracking
861
+
862
+ Returns: A list of a list of tensor dimensions
863
+
864
+ """
865
+ res = []
866
+
867
+ for _ in range(num_tensors):
868
+ dims, counter = gen_tensor_dims(dim_size, counter)
869
+ res.append(dims)
870
+
871
+ return res, counter
872
+
873
+
874
+ def create_equality_constraints_for_broadcasting(e1: TVar,
875
+ e2: TVar,
876
+ e11: TVar,
877
+ e12: TVar,
878
+ d1: List[DVar],
879
+ d2: List[DVar],
880
+ d11: List[DVar],
881
+ d12: List[DVar]):
882
+ """
883
+ Create equality constraints for when no broadcasting occurs
884
+ Args:
885
+ e1: Input 1
886
+ e2: Input 2
887
+ e11: Broadcasted input 1
888
+ e12: Broadcasted input 2
889
+ d1: Variables that store dimensions for e1
890
+ d2: Variables that store dimensions for e2
891
+ d11: Variables that store dimensions for e11
892
+ d12: Variables that store dimensions for e22
893
+
894
+ Returns: Four equality constraints
895
+
896
+ """
897
+
898
+ e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
899
+ e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
900
+ e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
901
+ e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
902
+ return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]
903
+
904
+
905
+ def gen_consistency_constraints(constraint: Constraint, counter: int):
906
+ """
907
+ Args:
908
+ constraint: Consistency constraint on tensors
909
+ counter: for variable tracking
910
+
911
+ Returns: Equality and consistency constraints on dimensions
912
+
913
+ """
914
+
915
+ all_constraints = []
916
+
917
+ for i in range(1, MAX_TENSOR_RANK + 1):
918
+ new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
919
+ new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
920
+
921
+ nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
922
+
923
+ c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
924
+ BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] +
925
+ [BinConstraintD(d1, d2, op_consistency) for
926
+ d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints)
927
+
928
+ all_constraints.append(c_tensor_i)
929
+
930
+ return all_constraints, counter
931
+
932
+
933
+ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
934
+ """
935
+ Args:
936
+ constraint: Greatest upper bound on tensors
937
+ counter: variable tracking
938
+
939
+ Returns: A set of equality constraints and DGreatestUpperBound constraints
940
+
941
+ """
942
+
943
+ all_constraints = []
944
+
945
+ for i in range(1, MAX_TENSOR_RANK + 1):
946
+ c = []
947
+ dims1, counter = gen_tensor_dims(i, counter)
948
+ c1tensor = TensorType(dims1)
949
+
950
+ dims2, counter = gen_tensor_dims(i, counter)
951
+ c2tensor = TensorType(dims2)
952
+
953
+ dims3, counter = gen_tensor_dims(i, counter)
954
+ c3tensor = TensorType(dims3)
955
+
956
+ c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
957
+ BinConstraintT(constraint.rhs2, c2tensor, op_eq),
958
+ BinConstraintT(constraint.res, c3tensor, op_eq)] + \
959
+ gen_nat_constraints(dims1 + dims2 + dims3)
960
+
961
+ assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
962
+ for i in range(len(c3tensor.__args__)):
963
+ c.append(DGreatestUpperBound(c3tensor.__args__[i],
964
+ c1tensor.__args__[i],
965
+ c2tensor.__args__[i]))
966
+
967
+ all_constraints.append(Conj(c))
968
+ return all_constraints, counter
969
+
970
+
971
+ def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
972
+ """
973
+ Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
974
+ We look at all combinations for all dimensions in d1 and d2
975
+ Args:
976
+ d1: input1 dimensions
977
+ d2: input2 dimensions
978
+ d11: broadcasted input1 dimensions
979
+ d12: broadcasted input2 dimensions
980
+
981
+ Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions
982
+
983
+ """
984
+
985
+ size = len(d1)
986
+
987
+ res2 = []
988
+
989
+ for i in range(size):
990
+ t1 = broadcast_dim(d1, d2, d11, d12, i)
991
+ t2 = broadcast_dim(d2, d1, d12, d11, i)
992
+ t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
993
+
994
+ res2.append(Disj([t1, t2, t3]))
995
+
996
+ return Conj(res2)
997
+
998
+
999
+ def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
1000
+ """
1001
+ Simulates broadcasting on e1 and e2 and returns the results
1002
+ respectively in e11 and e12. Because of gradual types,
1003
+ e1 and e2 may not be equal. Similarly, e11 and e12 may not
1004
+ be equal. e11 and e12 should be guaranteed to be consistent
1005
+ as they represent the shapes of the tensors to be added after
1006
+ broadcasting.
1007
+ Args:
1008
+ e1: TVar representing the type of input 1
1009
+ e2: TVar representing the type of input 2
1010
+ e11: TVar representing the representing broadcasted input 1
1011
+ e12: TVar representing the representing broadcasted input 2
1012
+ i: The rank of the resulting type of addition
1013
+ counter: for variable tracking
1014
+
1015
+ Returns: Simplified broadcasting constraints
1016
+
1017
+ """
1018
+ dims, counter = gen_lists_of_dims(4, i, counter)
1019
+ [d1, d2, d3, d4] = dims
1020
+ nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
1021
+
1022
+ initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
1023
+ d1, d2, d3, d4)
1024
+
1025
+ [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints
1026
+
1027
+ # without padding, broadcast all possibilities for tensors of size i
1028
+ final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints,
1029
+ generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])
1030
+
1031
+ # with padding, broadcast all possibilities for tensors of size i
1032
+ final_tensor_constraint_padding_arg1, counter = \
1033
+ apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)
1034
+
1035
+ final_tensor_constraint_padding_arg2, counter = \
1036
+ apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)
1037
+
1038
+ return final_tensor_constraint_no_padding, \
1039
+ final_tensor_constraint_padding_arg1, \
1040
+ final_tensor_constraint_padding_arg2, nat_dims_i, counter
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
2
+ from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
3
+ from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
4
+ from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
5
+ from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
6
+ from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
7
+ from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
8
+ from torch.fx.tensor_type import TensorType, Dyn
9
+
10
+ try:
11
+ import z3 # type: ignore[import]
12
+ from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
13
+ HAS_Z3 = True
14
+
15
+ def transform_to_z3(constraint, counter, dimension_dict):
16
+ if isinstance(constraint, Conj):
17
+ conjuncts = []
18
+ for c in constraint.conjucts:
19
+ new_c, counter = transform_to_z3(c, counter, dimension_dict)
20
+ conjuncts.append(new_c)
21
+ return z3.And(conjuncts), counter
22
+
23
+ elif isinstance(constraint, Disj):
24
+ disjuncts = []
25
+ for c in constraint.disjuncts:
26
+ new_c, counter = transform_to_z3(c, counter, dimension_dict)
27
+ disjuncts.append(new_c)
28
+ return z3.Or(disjuncts), counter
29
+
30
+ elif isinstance(constraint, T):
31
+ return True, counter
32
+
33
+ elif isinstance(constraint, F):
34
+ return False, counter
35
+
36
+ elif isinstance(constraint, BinConstraintT):
37
+ if constraint.op == op_eq:
38
+ lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
39
+ rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
40
+ return (lhs == rhs), counter
41
+
42
+ else:
43
+ raise NotImplementedError('Method not yet implemented')
44
+
45
+ elif isinstance(constraint, BinConstraintD):
46
+ if constraint.op == op_eq:
47
+
48
+ if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
49
+ transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
50
+ transformed_lhs = z3.Bool(constraint.lhs.c)
51
+ return transformed_lhs == transformed_rhs, counter
52
+
53
+ elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
54
+ # with dimension transformations we consider the encoding
55
+ lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
56
+ rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
57
+ return lhs == rhs, counter
58
+
59
+ else:
60
+ # then we have an algebraic expression which means that we disregard the
61
+ # first element of the encoding
62
+ lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
63
+ rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
64
+ return lhs == rhs, counter
65
+
66
+ # The assumption here is that the LHS and RHS must be dimensions
67
+ elif constraint.op == op_neq:
68
+ assert is_dim(constraint.lhs)
69
+ assert is_dim(constraint.rhs)
70
+ lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
71
+ rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
72
+ if constraint.rhs == Dyn or constraint.lhs == Dyn:
73
+ if constraint.rhs == Dyn:
74
+ return lhs.arg(0) == 1, counter
75
+ elif constraint.lhs == Dyn:
76
+ return rhs.arg(0) == 1, counter
77
+
78
+ # if one of the instances is a number
79
+ elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
80
+ if isinstance(constraint.lhs, int):
81
+ return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
82
+
83
+ elif isinstance(constraint.rhs, int):
84
+ return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
85
+
86
+ else:
87
+ return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
88
+ z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
89
+ z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
90
+
91
+
92
+ elif constraint.op == op_leq:
93
+ # if the dimensions are not dyn, this will come into effect
94
+ # there would have been another constraint specifying if a given dimension
95
+ # is dyn or not
96
+ assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
97
+ lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
98
+ rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
99
+ return lhs <= rhs, counter
100
+
101
+ elif constraint.op == op_gt:
102
+ assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
103
+ lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
104
+ rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
105
+ return lhs > rhs, counter
106
+
107
+ elif constraint.op == op_lt:
108
+ assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
109
+ lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
110
+ rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
111
+ return lhs < rhs, counter
112
+
113
+ else:
114
+ raise NotImplementedError('operation not yet implemented')
115
+
116
+ else:
117
+ raise NotImplementedError('Operation not yet implemented')
118
+
119
+
120
+ def transform_var(tensor, counter, dimension_dict):
121
+ """
122
+ Transforms tensor variables to a format understood by z3
123
+ Args:
124
+ tensor: Tensor variable or a tensor type potentially with variable dimensions
125
+ Returns: Transformed variable to a z3 format
126
+
127
+ """
128
+ if isinstance(tensor, TensorType):
129
+ res = []
130
+ for t in tensor.__args__:
131
+ transformed, counter = transform_dimension(t, counter, dimension_dict)
132
+ res.append(transformed)
133
+
134
+ assert len(res) <= 4
135
+ if len(tensor.__args__) == 1:
136
+ return tensor_type.tensor1(res[0]), counter
137
+ elif len(tensor.__args__) == 2:
138
+ return tensor_type.tensor2(res[0], res[1]), counter
139
+ elif len(tensor.__args__) == 3:
140
+ return tensor_type.tensor3(res[0], res[1], res[2]), counter
141
+ elif len(tensor.__args__) == 4:
142
+ return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
143
+
144
+ elif tensor == Dyn:
145
+ return z3_dyn, counter
146
+
147
+ elif isinstance(tensor, TVar):
148
+ return z3.Const(tensor.tvar, tensor_type), counter
149
+
150
+ def transform_dimension(dimension, counter, dimension_dict):
151
+ """
152
+ Takes a dimension variable or a number and transforms it to a tuple
153
+ according to our scheme
154
+ Args:
155
+ dimension: The dimension to be transformed
156
+ counter: variable tracking
157
+
158
+ Returns: tuple and the current counter
159
+
160
+ """
161
+ if dimension == Dyn:
162
+ counter += 1
163
+ return D(0, z3.Int(counter)), counter
164
+ elif isinstance(dimension, int):
165
+ return D(1, dimension), counter
166
+ elif isinstance(dimension, DVar):
167
+ if dimension.c in dimension_dict:
168
+ return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
169
+ else:
170
+ counter += 1
171
+ dimension_dict[dimension.c] = counter
172
+ return D(z3.Int(counter), z3.Int(dimension.c)), counter
173
+
174
+
175
def transform_algebraic_expression(expr, counter, dimension_dict):
    """
    Transform an algebraic expression (a plain dimension, a product of
    dimensions, or a binary arithmetic expression) to z3 format.

    Args:
        expr: a dimension variable or an algebraic expression
        counter: variable tracking
        dimension_dict: maps dimension-variable names to assigned counters

    Returns: the transformed z3 expression and the current counter
    """
    assert is_algebraic_expression(expr) or is_dim(expr)

    if is_dim(expr):
        # A plain dimension: encode it and keep only the value component.
        encoded, counter = transform_dimension(expr, counter, dimension_dict)
        return encoded.arg(1), counter

    if isinstance(expr, Prod):
        # Product of dimensions: encode each factor and multiply the values.
        factors = []
        for factor in expr.products:
            assert is_dim(factor)
            encoded, counter = transform_dimension(factor, counter, dimension_dict)
            factors.append(encoded.arg(1))
        return z3.Product(factors), counter

    if is_algebraic_expression(expr):
        # Binary expression: transform both operands, then combine.
        left, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
        right, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)

        if expr.op == op_sub:
            combined = left - right
        elif expr.op == op_add:
            combined = left + right
        elif expr.op == op_div:
            combined = left / right
        elif expr.op == op_mul:
            combined = left * right
        elif expr.op == op_mod:
            combined = left % right
        else:
            raise NotImplementedError('operation not yet implemented')

        return combined, counter

    raise RuntimeError
227
+
228
+
229
def transform_all_constraints(traced, counter=0):
    """
    Given a trace, generate its typing constraints, simplify them to a
    fixed point, and transform the result to z3 format.

    Args:
        traced: the traced module to generate constraints for
        counter: variable tracking

    Returns: the z3 formula encoding all constraints
    """
    dimension_dict = {}  # type: ignore[var-annotated]

    new_constraints, counter = ConstraintGenerator(traced).generate_constraints(counter)

    # transform precision, matching, consistency till obtaining a fixed point
    new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)

    transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
    return transformed
252
+
253
def iterate_till_fixed_point(constraints, counter):
    """
    Repeatedly apply transform_constraint until the constraints stop changing
    (i.e. a fixed point is reached).
    """
    previous = None
    while previous != constraints:
        previous = constraints
        constraints, counter = transform_constraint(constraints, counter)
    return constraints, counter
262
+
263
def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
    """
    Takes a node and a graph and generates two sets of constraints.
    One set constraints the node's constraints and another set
    constraints the negation of the node's constraints
    Args:
        tracer_root: the root for getting the module instances
        graph: the graph so far in the tracing process
        node: node that represents a conditional
        counter: variable tracking

    Returns: Two sets of constraints. One with a conjunction with the
    the conditional constraint and the other with a conjunction with
    its negation.

    """
    dimension_dict = {}  # type: ignore[var-annotated]

    generator = ConstraintGenerator(tracer_root, graph)
    new_constraints, counter = generator.generate_constraints(counter)

    # The generator appends the conditional's constraint last.
    condition_constraint = new_constraints.conjucts[-1]

    # we know the constraint is a conjunction where the last constraint is about the conditional
    # so remove the last constraint
    new_constraints.conjucts = new_constraints.conjucts[:-1]

    # transform precision, matching, consistency till obtaining a fixed point
    new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)

    # since the function returns a list of one element, we get the first element
    # we are only interested in the RHS in this case because the LHS just stores
    # the result

    # we make sure the constraint is of the form:
    # c = b where b is a boolean expression
    # and we consider b (constraint.rhs) for transformation
    assert isinstance(condition_constraint.lhs, BVar)
    assert is_bool_expr(condition_constraint.rhs)
    condition_constraint_rhs = condition_constraint.rhs

    # transform the condition constraint
    condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)

    # NOTE: both transformations share dimension_dict so dimension variables
    # in the condition line up with those in the rest of the constraints.
    transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)

    transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)

    negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)

    return z3.And([transformed, transformed_condition_constraint]), \
        z3.And([transformed, negation_transformed_condition_constraint])
316
+
317
+
318
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
    """
    Given an IR and a node representing a conditional, check satisfiability of
    the conditional and of its negation, each conjoined with the remaining
    constraints (and optional user-provided constraints).

    Args:
        tracer_root: Tracer root for module instances
        graph: the graph so far in the tracing process
        node: The node to be evaluated
        counter: variable tracking
        user_constraints: optional extra z3 constraints to add to both checks

    Returns: the satisfiability results for the condition and its negation
    """
    positive, negative = \
        transform_all_constraints_trace_time(tracer_root, graph, node, counter)

    def check(formula):
        # Fresh solver per formula so the two checks do not interfere.
        solver = z3.Solver()
        solver.add(formula)
        if user_constraints is not None:
            solver.add(user_constraints)
        return solver.check()

    return check(positive), check(negative)
346
+
347
+ except ImportError:
348
+ HAS_Z3 = False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc ADDED
Binary file (4.21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.25 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (492 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
4
+
5
def raises(err, lamda):
    """Return True if calling *lamda* raises *err*, False if it returns normally.

    Exceptions not matching *err* propagate to the caller.
    """
    try:
        lamda()
    except err:
        return True
    return False
11
+
12
+
13
def expand_tuples(L):
    """
    >>> expand_tuples([1, (2, 3)])
    [(1, 2), (1, 3)]
    >>> expand_tuples([1, 2])
    [(1, 2)]
    """
    if not L:
        return [()]
    # Expand the tail first, then prepend each alternative of the head.
    tails = expand_tuples(L[1:])
    head = L[0]
    if isinstance(head, tuple):
        return [(item,) + tail for tail in tails for item in head]
    return [(head,) + tail for tail in tails]
28
+
29
+
30
+ # Taken from theano/theano/gof/sched.py
31
+ # Avoids licensing issues because this was written by Matthew Rocklin
32
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> _toposort({1: (2, 3), 2: (3, )})
    [1, 2, 3]
    >>> # Closely follows the wikipedia page [2]
    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
    >>> # Communications of the ACM
    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    # remaining[n] = set of not-yet-emitted predecessors of n.
    remaining = OrderedDict(
        (child, set(parents)) for child, parents in reverse_dict(edges).items()
    )
    # Nodes with no incoming edge are immediately ready.
    ready = OrderedDict.fromkeys(node for node in edges if node not in remaining)
    ordered = []

    while ready:
        node, _ = ready.popitem()  # OrderedDict.popitem() pops LIFO
        ordered.append(node)
        for child in edges.get(node, ()):
            assert node in remaining[child]
            remaining[child].remove(node)
            if not remaining[child]:
                ready[child] = None
    # Any node with predecessors left means the emitted order never drained
    # its incoming edges, i.e. the graph has a cycle.
    if any(remaining.get(node, None) for node in edges):
        raise ValueError("Input has cycles")
    return ordered
62
+
63
+
64
def reverse_dict(d):
    """Reverses direction of dependence dict
    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    :note: dict order are not deterministic. As we iterate on the
            input dict, it make the output of this function depend on the
            dict order. So this function output order should be considered
            as undeterministic.
    """
    result = OrderedDict()  # type: ignore[var-annotated]
    for key, values in d.items():
        for value in values:
            existing = result.get(value, tuple())
            result[value] = existing + (key,)
    return result
79
+
80
+
81
+ # Taken from toolz
82
+ # Avoids licensing issues because this version was authored by Matthew Rocklin
83
def groupby(func, seq):
    """ Group a collection by a key function
    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
    See Also:
        ``countby``
    """
    grouped = OrderedDict()  # type: ignore[var-annotated]
    for item in seq:
        grouped.setdefault(func(item), []).append(item)
    return grouped
102
+
103
+
104
def typename(type):
    """Get the name of `type`.
    Parameters
    ----------
    type : Union[Type, Tuple[Type]]
    Returns
    -------
    str
        The name of `type` or a tuple of the names of the types in `type`.
    Examples
    --------
    >>> typename(int)
    'int'
    >>> typename((int, float))
    '(int, float)'
    """
    name = getattr(type, "__name__", None)
    if name is not None:
        return name
    # A tuple of types: a single element collapses to its own name,
    # otherwise render a parenthesized, comma-separated list.
    if len(type) == 1:
        return typename(*type)
    return f"({', '.join(typename(t) for t in type)})"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc ADDED
Binary file (30.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+
5
+
6
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of sequence node.
    If sequence node is not annotated with a type, do nothing.
    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.

    This is helpful since annotations on local names within function are lost during FX transforms.
    Adding back known type annotation for getitem nodes to improve jit scriptability.

    Args:
        graph (Graph): The graph to be annotated
    """
    for node in graph.nodes:
        if node.target == operator.getitem:
            # getitem args are (sequence, index); index_node is presumably a
            # constant int when used for Tuple/NamedTuple lookup -- TODO confirm.
            sequence_node, index_node = node.args
            if not sequence_node.type:
                # No annotation on the sequence -> nothing to propagate.
                continue
            # container types
            if hasattr(sequence_node.type, "_name"):
                parameterized_types = sequence_node.type.__args__
                if sequence_node.type._name == "Tuple":
                    if len(parameterized_types) == 2 and isinstance(
                        parameterized_types[1], type(...)
                    ):
                        # Homogeneous Tuple[T, ...]: every element has type T.
                        node.type = parameterized_types[0]
                    else:
                        # Heterogeneous tuple: take the element type at the index.
                        assert len(parameterized_types) > index_node
                        node_type = parameterized_types[index_node]
                        node.type = node_type
                elif sequence_node.type._name == "List":
                    # List[T] carries a single parameter; any index yields T.
                    assert len(parameterized_types) == 1
                    node.type = parameterized_types[0]
            # NamedTuple type
            elif hasattr(sequence_node.type, "__annotations__"):
                # torch.Tensor also has __annotations__ but is not a NamedTuple.
                if sequence_node.type == torch.Tensor:
                    continue
                # Map the positional index to a field name, then to its type.
                sequence_node_field_types = sequence_node.type.__annotations__
                field_name = sequence_node.type._fields[index_node]
                node.type = sequence_node_field_types[field_name]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple, Any
2
+
3
+ import torch
4
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
5
+ from torch.utils._pytree import tree_flatten
6
+
7
+ from torch.fx import GraphModule, Graph
8
+ from torch.fx import Node
9
+
10
+ aten = torch.ops.aten
11
+
12
+
13
+ # stateful ops are banned from CSE
14
+ rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950
15
+
16
+ inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
17
+
18
+
19
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
    """Return the default set of ops that must never be CSE'd: random and in-place ops."""
    return rand_ops | inplace_ops
22
+
23
+
24
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
    # Common-subexpression-elimination pass over an FX graph: nodes with the
    # same target and (substituted) args/kwargs are merged into one node.

    def __init__(self, banned_ops=None):
        """
        This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node.

        For functional dialects, user would only need to specify the random ops in ban list.

        Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects.
        If your dialect contains stateful operators, please customized the banned_ops.

        """
        if banned_ops is None:
            banned_ops = set()
        # Ops in this set are copied verbatim and never deduplicated.
        self.banned_ops = banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b+c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        def get_aten_target(node):
            # Normalize an OpOverload (e.g. aten.add.Tensor) to its overload
            # packet (aten.add) so membership tests against banned_ops match.
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change
            # do not CSE away random operations
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
                # substitute args and kwargs members to their mapping in env if exists
                # specs can be used to reconstruct nested list/dictionaries
                def substitute(arg_list):
                    arg_list, spec = tree_flatten(arg_list)
                    for i in range(len(arg_list)):
                        v = arg_list[i]
                        if isinstance(v, Node) and v in env:
                            arg_list[i] = env[v]
                    return tuple(arg_list), spec
                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # each token corresponds to a unique node
                # nodes with the same token can be substituted
                token = {"target": n.target, "args": args, "args_spec": args_spec,
                         "kwargs": kwargs, "kwargs_spec": kwargs_spec}

                # hash substituted args to a number, do not hash specs because specs are not hashable
                hash_arg = hash((args, kwargs))
                hash_val = (n.target, hash_arg)

                # check if a node has a substitute and can be eliminated
                # the full-token comparison guards against hash collisions
                hash_val_in_hash_env = hash_val in hash_env
                if hash_val_in_hash_env and token_map[hash_val] == token:
                    modified = True  # substitution happens and the graph is modified
                    env[n] = hash_env[hash_val]
                    continue

                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                if not hash_val_in_hash_env:
                    hash_env[hash_val] = new_node
                    token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch.fx
4
+ from torch.fx import Node
5
+ from torch.fx._compatibility import compatibility
6
+ from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
7
+ from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
8
+ from torch.fx.node import map_aggregate
9
+
10
+ __all__ = ['FakeTensorProp']
11
+
12
@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record a fake tensor representing
    the metadata for the node.  Unlike ShapeProp, (1) this propagation
    is cheap--it does the propagation with meta tensors which do not actually
    store data, and (2) the fake tensors have much more fine grained information,
    e.g., they have accurate alias information that can be consulted by looking
    at the storages.

    Args:
        module (GraphModule): The module to be executed
        mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
    """
    def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
        super().__init__(module)
        if mode is None:
            mode = FakeTensorMode()
        # Fake mode all node execution is dispatched under.
        self._mode = mode

    def run_node(self, n: Node):
        # Local imports keep sympy off the module import path.
        import sympy
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        result = super().run_node(n)
        sym = None
        if (
            'val' in n.meta and
            isinstance(v := n.meta['val'], torch.SymInt) and
            isinstance(v.node.expr, sympy.Symbol) and free_unbacked_symbols(v)
        ):
            # The node previously recorded an unbacked SymInt; remember it so
            # the re-propagated value can be asserted against it below.
            sym = v

        def extract_val(obj):
            # Snapshot fake tensors, pass symbolic values through, drop the rest.
            if isinstance(obj, FakeTensor):
                return snapshot_fake(obj)
            elif isinstance(obj, torch.Tensor):
                # TODO: How is it possible that we get a non fake tensor? We
                # should be running under the mode...
                return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
            elif isinstance(obj, py_sym_types):
                return obj
            else:
                return None

        meta = map_aggregate(result, extract_val)
        if meta is not None:
            n.meta['val'] = meta
            if sym is not None:
                # New value must agree with the previously recorded SymInt.
                torch._check(meta == v)
        return result

    def propagate(self, *args):
        # Convert real tensor args to fake tensors, then run under the mode.
        fake_args = [
            self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
            for a in args
        ]
        return self.propagate_dont_convert_inputs(*fake_args)

    def propagate_dont_convert_inputs(self, *args):
        # Args are assumed to already be fake (or non-tensor); just run.
        with self._mode:
            return super().run(*args)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import hashlib
3
+ import torch
4
+ import torch.fx
5
+ from typing import Any, Dict, Optional, TYPE_CHECKING
6
+ from torch.fx.node import _get_qualified_name, _format_arg
7
+ from torch.fx.graph import _parse_stack_trace
8
+ from torch.fx.passes.shape_prop import TensorMetadata
9
+ from torch.fx._compatibility import compatibility
10
+ from itertools import chain
11
+
12
+ __all__ = ['FxGraphDrawer']
13
+ try:
14
+ import pydot
15
+ HAS_PYDOT = True
16
+ except ImportError:
17
+ HAS_PYDOT = False
18
+
19
+ _COLOR_MAP = {
20
+ "placeholder": '"AliceBlue"',
21
+ "call_module": "LemonChiffon1",
22
+ "get_param": "Yellow2",
23
+ "get_attr": "LightGrey",
24
+ "output": "PowderBlue",
25
+ }
26
+
27
+ _HASH_COLOR_MAP = [
28
+ "CadetBlue1",
29
+ "Coral",
30
+ "DarkOliveGreen1",
31
+ "DarkSeaGreen1",
32
+ "GhostWhite",
33
+ "Khaki1",
34
+ "LavenderBlush1",
35
+ "LightSkyBlue",
36
+ "MistyRose1",
37
+ "MistyRose2",
38
+ "PaleTurquoise2",
39
+ "PeachPuff1",
40
+ "Salmon",
41
+ "Thistle1",
42
+ "Thistle3",
43
+ "Wheat1",
44
+ ]
45
+
46
+ _WEIGHT_TEMPLATE = {
47
+ "fillcolor": "Salmon",
48
+ "style": '"filled,rounded"',
49
+ "fontcolor": "#000000",
50
+ }
51
+
52
+ if HAS_PYDOT:
53
+ @compatibility(is_backward_compatible=False)
54
+ class FxGraphDrawer:
55
+ """
56
+ Visualize a torch.fx.Graph with graphviz
57
+ Basic usage:
58
+ g = FxGraphDrawer(symbolic_traced, "resnet18")
59
+ g.get_dot_graph().write_svg("a.svg")
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ graph_module: torch.fx.GraphModule,
65
+ name: str,
66
+ ignore_getattr: bool = False,
67
+ ignore_parameters_and_buffers: bool = False,
68
+ skip_node_names_in_args: bool = True,
69
+ parse_stack_trace: bool = False,
70
+ dot_graph_shape: Optional[str] = None,
71
+ ):
72
+ self._name = name
73
+ self.dot_graph_shape = (
74
+ dot_graph_shape if dot_graph_shape is not None else "record"
75
+ )
76
+ _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape
77
+
78
+ self._dot_graphs = {
79
+ name: self._to_dot(
80
+ graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
81
+ )
82
+ }
83
+
84
+ for node in graph_module.graph.nodes:
85
+ if node.op != "call_module":
86
+ continue
87
+
88
+ leaf_node = self._get_leaf_node(graph_module, node)
89
+
90
+ if not isinstance(leaf_node, torch.fx.GraphModule):
91
+ continue
92
+
93
+
94
+ self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
95
+ leaf_node,
96
+ f"{name}_{node.target}",
97
+ ignore_getattr,
98
+ ignore_parameters_and_buffers,
99
+ skip_node_names_in_args,
100
+ parse_stack_trace,
101
+ )
102
+
103
+ def get_dot_graph(self, submod_name=None) -> pydot.Dot:
104
+ """
105
+ Visualize a torch.fx.Graph with graphviz
106
+ Example:
107
+ >>> # xdoctest: +REQUIRES(module:pydot)
108
+ >>> # define module
109
+ >>> class MyModule(torch.nn.Module):
110
+ >>> def __init__(self):
111
+ >>> super().__init__()
112
+ >>> self.linear = torch.nn.Linear(4, 5)
113
+ >>> def forward(self, x):
114
+ >>> return self.linear(x).clamp(min=0.0, max=1.0)
115
+ >>> module = MyModule()
116
+ >>> # trace the module
117
+ >>> symbolic_traced = torch.fx.symbolic_trace(module)
118
+ >>> # setup output file
119
+ >>> import ubelt as ub
120
+ >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
121
+ >>> fpath = dpath / 'linear.svg'
122
+ >>> # draw the graph
123
+ >>> g = FxGraphDrawer(symbolic_traced, "linear")
124
+ >>> g.get_dot_graph().write_svg(fpath)
125
+ """
126
+ if submod_name is None:
127
+ return self.get_main_dot_graph()
128
+ else:
129
+ return self.get_submod_dot_graph(submod_name)
130
+
131
+ def get_main_dot_graph(self) -> pydot.Dot:
132
+ return self._dot_graphs[self._name]
133
+
134
+ def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
135
+ return self._dot_graphs[f"{self._name}_{submod_name}"]
136
+
137
+ def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
138
+ return self._dot_graphs
139
+
140
+ def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
141
+
142
+ template = {
143
+ "shape": self.dot_graph_shape,
144
+ "fillcolor": "#CAFFE3",
145
+ "style": '"filled,rounded"',
146
+ "fontcolor": "#000000",
147
+ }
148
+ if node.op in _COLOR_MAP:
149
+ template["fillcolor"] = _COLOR_MAP[node.op]
150
+ else:
151
+ # Use a random color for each node; based on its name so it's stable.
152
+ target_name = node._pretty_print_target(node.target)
153
+ target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
154
+ template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
155
+ return template
156
+
157
+ def _get_leaf_node(
158
+ self, module: torch.nn.Module, node: torch.fx.Node
159
+ ) -> torch.nn.Module:
160
+ py_obj = module
161
+ assert isinstance(node.target, str)
162
+ atoms = node.target.split(".")
163
+ for atom in atoms:
164
+ if not hasattr(py_obj, atom):
165
+ raise RuntimeError(
166
+ str(py_obj) + " does not have attribute " + atom + "!"
167
+ )
168
+ py_obj = getattr(py_obj, atom)
169
+ return py_obj
170
+
171
+ def _typename(self, target: Any) -> str:
172
+ if isinstance(target, torch.nn.Module):
173
+ ret = torch.typename(target)
174
+ elif isinstance(target, str):
175
+ ret = target
176
+ else:
177
+ ret = _get_qualified_name(target)
178
+
179
+ # Escape "{" and "}" to prevent dot files like:
180
+ # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
181
+ # which triggers `Error: bad label format (...)` from dot
182
+ return ret.replace("{", r"\{").replace("}", r"\}")
183
+
184
+ # shorten path to avoid drawing long boxes
185
+ # for full path = '/home/weif/pytorch/test.py'
186
+ # return short path = 'pytorch/test.py'
187
+ def _shorten_file_name(
188
+ self,
189
+ full_file_name: str,
190
+ truncate_to_last_n: int = 2,
191
+ ):
192
+ splits = full_file_name.split('/')
193
+ if len(splits) >= truncate_to_last_n:
194
+ return '/'.join(splits[-truncate_to_last_n:])
195
+ return full_file_name
196
+
197
+
198
+ def _get_node_label(
199
+ self,
200
+ module: torch.fx.GraphModule,
201
+ node: torch.fx.Node,
202
+ skip_node_names_in_args: bool,
203
+ parse_stack_trace: bool,
204
+ ) -> str:
205
+ def _get_str_for_args_kwargs(arg):
206
+ if isinstance(arg, tuple):
207
+ prefix, suffix = r"|args=(\l", r",\n)\l"
208
+ arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
209
+ elif isinstance(arg, dict):
210
+ prefix, suffix = r"|kwargs={\l", r",\n}\l"
211
+ arg_strs_list = [
212
+ f"{k}: {_format_arg(v, max_list_len=8)}"
213
+ for k, v in arg.items()
214
+ ]
215
+ else: # Fall back to nothing in unexpected case.
216
+ return ""
217
+
218
+ # Strip out node names if requested.
219
+ if skip_node_names_in_args:
220
+ arg_strs_list = [a for a in arg_strs_list if "%" not in a]
221
+ if len(arg_strs_list) == 0:
222
+ return ""
223
+ arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
224
+ if len(arg_strs_list) == 1:
225
+ arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
226
+ return arg_strs.replace("{", r"\{").replace("}", r"\}")
227
+
228
+
229
+ label = "{" + f"name=%{node.name}|op_code={node.op}\n"
230
+
231
+ if node.op == "call_module":
232
+ leaf_module = self._get_leaf_node(module, node)
233
+ label += r"\n" + self._typename(leaf_module) + r"\n|"
234
+ extra = ""
235
+ if hasattr(leaf_module, "__constants__"):
236
+ extra = r"\n".join(
237
+ [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr]
238
+ )
239
+ label += extra + r"\n"
240
+ else:
241
+ label += f"|target={self._typename(node.target)}" + r"\n"
242
+ if len(node.args) > 0:
243
+ label += _get_str_for_args_kwargs(node.args)
244
+ if len(node.kwargs) > 0:
245
+ label += _get_str_for_args_kwargs(node.kwargs)
246
+ label += f"|num_users={len(node.users)}" + r"\n"
247
+
248
+ tensor_meta = node.meta.get('tensor_meta')
249
+ label += self._tensor_meta_to_label(tensor_meta)
250
+
251
+ # for original fx graph
252
+ # print buf=buf0, n_origin=6
253
+ buf_meta = node.meta.get('buf_meta', None)
254
+ if buf_meta is not None:
255
+ label += f"|buf={buf_meta.name}" + r"\n"
256
+ label += f"|n_origin={buf_meta.n_origin}" + r"\n"
257
+
258
+ # for original fx graph
259
+ # print file:lineno code
260
+ if parse_stack_trace and node.stack_trace is not None:
261
+ parsed_stack_trace = _parse_stack_trace(node.stack_trace)
262
+ fname = self._shorten_file_name(parsed_stack_trace.file)
263
+ label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"
264
+
265
+
266
+ return label + "}"
267
+
268
+ def _tensor_meta_to_label(self, tm) -> str:
269
+ if tm is None:
270
+ return ""
271
+ elif isinstance(tm, TensorMetadata):
272
+ return self._stringify_tensor_meta(tm)
273
+ elif isinstance(tm, list):
274
+ result = ""
275
+ for item in tm:
276
+ result += self._tensor_meta_to_label(item)
277
+ return result
278
+ elif isinstance(tm, dict):
279
+ result = ""
280
+ for v in tm.values():
281
+ result += self._tensor_meta_to_label(v)
282
+ return result
283
+ elif isinstance(tm, tuple):
284
+ result = ""
285
+ for item in tm:
286
+ result += self._tensor_meta_to_label(item)
287
+ return result
288
+ else:
289
+ raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
290
+
291
+ def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
292
+ result = ""
293
+ if not hasattr(tm, "dtype"):
294
+ print("tm", tm)
295
+ result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
296
+ result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
297
+ result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
298
+ result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
299
+ if tm.is_quantized:
300
+ assert tm.qparams is not None
301
+ assert "qscheme" in tm.qparams
302
+ qscheme = tm.qparams["qscheme"]
303
+ if qscheme in {
304
+ torch.per_tensor_affine,
305
+ torch.per_tensor_symmetric,
306
+ }:
307
+ result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
308
+ result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
309
+ elif qscheme in {
310
+ torch.per_channel_affine,
311
+ torch.per_channel_symmetric,
312
+ torch.per_channel_affine_float_qparams,
313
+ }:
314
+ result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
315
+ result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
316
+ result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
317
+ else:
318
+ raise RuntimeError(f"Unsupported qscheme: {qscheme}")
319
+ result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
320
+ return result
321
+
322
+ def _get_tensor_label(self, t: torch.Tensor) -> str:
323
+ return str(t.dtype) + str(list(t.shape)) + r"\n"
324
+
325
+ # when parse_stack_trace=True
326
+ # print file:lineno code
327
    def _to_dot(
        self,
        graph_module: torch.fx.GraphModule,
        name: str,
        ignore_getattr: bool,
        ignore_parameters_and_buffers: bool,
        skip_node_names_in_args: bool,
        parse_stack_trace: bool,
    ) -> pydot.Dot:
        """
        Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
        If ignore_parameters_and_buffers is True, the parameters and buffers
        created with the module will not be added as nodes and edges.

        Nodes sharing a buffer (``buf_meta`` in node.meta with n_origin > 1) are
        grouped into a pydot Cluster named after the buffer.
        """

        # "TB" means top-to-bottom rank direction in layout
        dot_graph = pydot.Dot(name, rankdir="TB")


        # Maps buffer name -> pydot.Cluster; clusters are attached to dot_graph
        # only after all nodes have been assigned (see loop below).
        buf_name_to_subgraph = {}

        # First pass: create one dot node per fx node, placing it either in the
        # top-level graph or in its buffer's cluster.
        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            style = self._get_node_style(node)
            dot_node = pydot.Node(
                node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
            )

            current_graph = dot_graph

            # Nodes whose buffer is shared by more than one origin go into a
            # dedicated cluster keyed by the buffer name.
            buf_meta = node.meta.get('buf_meta', None)
            if buf_meta is not None and buf_meta.n_origin > 1:
                buf_name = buf_meta.name
                if buf_name not in buf_name_to_subgraph:
                    buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
                current_graph = buf_name_to_subgraph.get(buf_name)

            current_graph.add_node(dot_node)

            # NOTE: this closure late-binds `node` and `leaf_module` from the
            # enclosing loop iteration; it is only called right after
            # `leaf_module` is set for the current call_module node.
            def get_module_params_or_buffers():
                for pname, ptensor in chain(
                    leaf_module.named_parameters(), leaf_module.named_buffers()
                ):
                    pname1 = node.name + "." + pname
                    label1 = (
                        pname1 + "|op_code=get_" + "parameter"
                        if isinstance(ptensor, torch.nn.Parameter)
                        else "buffer" + r"\l"
                    )
                    dot_w_node = pydot.Node(
                        pname1,
                        label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                        **_WEIGHT_TEMPLATE,
                    )
                    dot_graph.add_node(dot_w_node)
                    dot_graph.add_edge(pydot.Edge(pname1, node.name))

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(graph_module, node)

                # Only draw weights for true leaf modules; nested GraphModules
                # are rendered through their own nodes.
                if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
                    get_module_params_or_buffers()

        # Attach buffer clusters to the main graph with a highlighted border.
        for subgraph in buf_name_to_subgraph.values():
            subgraph.set('color', 'royalblue')
            subgraph.set('penwidth', '2')
            dot_graph.add_subgraph(subgraph)

        # Second pass: add one data-flow edge per (producer, user) pair.
        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            for user in node.users:
                dot_graph.add_edge(pydot.Edge(node.name, user.name))

        return dot_graph
405
+
406
+ else:
407
+ if not TYPE_CHECKING:
408
        @compatibility(is_backward_compatible=False)
        class FxGraphDrawer:
            """Fallback stub used when pydot is not importable.

            Mirrors the real FxGraphDrawer constructor signature but raises
            RuntimeError immediately, so the failure surfaces at construction
            time with an actionable message.
            """
            def __init__(
                self,
                graph_module: torch.fx.GraphModule,
                name: str,
                ignore_getattr: bool = False,
                ignore_parameters_and_buffers: bool = False,
                skip_node_names_in_args: bool = True,
                parse_stack_trace: bool = False,
                dot_graph_shape: Optional[str] = None,
            ):
                raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
                                   'pydot through your favorite Python package manager.')
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (273 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
2
+ import collections
3
+ import itertools
4
+ import logging
5
+
6
+ from copy import copy
7
+ from typing import Dict, Iterable, List, Optional, Sequence, Set
8
+
9
+ from torch.fx.graph_module import GraphModule
10
+ from torch.fx.node import Node, _get_qualified_name
11
+ from torch.fx.passes.operator_support import OperatorSupportBase
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+ logger.setLevel(logging.WARNING)
16
+
17
class Partition:
    """A named group of fx nodes produced by the partitioner.

    ``id`` is the partition's identifier (may be None for ad-hoc partitions);
    ``nodes`` holds the member nodes as a set.
    """

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        # Materialize into a fresh set so the caller's iterable is never aliased.
        self.nodes: Set[Node] = set() if nodes is None else set(nodes)

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node):
        """Insert ``node`` into the partition (no-op if already present)."""
        self.nodes.add(node)

    def remove_node(self, node: Node):
        """Remove ``node``; raises KeyError when it is not a member."""
        self.nodes.remove(node)

    def size(self):
        """Number of nodes currently in the partition."""
        return len(self.nodes)
33
+
34
+ class _DependencyViewer:
35
+ def __init__(self, graph_module: GraphModule):
36
+ self.upstreams = collections.defaultdict(set)
37
+ self.downstreams = collections.defaultdict(set)
38
+
39
+ for node in graph_module.graph.nodes:
40
+ for input_node in node.all_input_nodes:
41
+ # add input_node and input_node's upstream dependency
42
+ self.upstreams[node].add(input_node)
43
+ self.upstreams[node].update(self.upstreams[input_node])
44
+
45
+ for node in reversed(graph_module.graph.nodes):
46
+ for output_node in node.users:
47
+ # add output_node and output_node's downstream dependency
48
+ self.downstreams[node].add(output_node)
49
+ self.downstreams[node].update(self.downstreams[output_node])
50
+
51
+ def downstreams_of(self, node: Node) -> Set[Node]:
52
+ return self.downstreams[node]
53
+
54
+ def upstreams_of(self, node: Node) -> Set[Node]:
55
+ return self.upstreams[node]
56
+
57
class CapabilityBasedPartitioner:
    """Partitions an FX graph into fusible subgraphs based on operator support.

    Nodes accepted by ``operator_support`` are greedily grouped into partitions;
    a merge between two partitions is rejected whenever it would introduce a
    cyclic dependency between partitions. Helpers exist to drop trivial
    single-node partitions and to strip non-compute ops from partition
    boundaries.
    """

    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False,
                 non_compute_ops: Optional[Sequence[str]] = None,
                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
                 ) -> None:
        self.graph_module = graph_module
        self.operator_support = operator_support
        # When False, partitions with <= 1 compute node are filtered out.
        self.allows_single_node_partition = allows_single_node_partition
        # Qualified op names treated as non-compute (in addition to defaults).
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        # Ops that count as compute even for single-node-partition filtering.
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        # Precomputed transitive upstream/downstream sets, used for cycle checks.
        self.dependency_viewer = _DependencyViewer(graph_module)

    def __is_node_supported(self, node: Node) -> bool:
        # Delegates the decision to the user-provided operator_support policy.
        return (
            self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
        )

    def propose_partitions(self) -> List[Partition]:
        """Compute and return the proposed partitions; does not mutate the graph."""
        # partition_map is a mapping from partition id to a set of partition id's.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map : Dict[int, Set] = collections.defaultdict(set)

        # assumptions: nodes in candidate list is sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        # try to merge partition other_id into partition self_id
        # merge only happens if the end graph doesn't contain cyclic dependency
        # returns `True` when merge happens, `False` otherwise.
        def maybe_merge_partition(self_id: int, other_id: int):
            # merged_nodes is the union of nodes in two partition to-be-merged
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)

            def dfs_iter_find_cycle(all_user_nodes: List[Node]):
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # If any of the nodes in the dfs path of this node are in the merged_nodes
                        # list then there is a cycle in the graph.
                        if path_node in merged_nodes:
                            return True

                        # If any of the nodes in the dfs path of this node are in the assignment
                        # map then we have to make sure that the partitions that these nodes belong
                        # to do not form a cycle with the current partitions being merged. This means
                        # iterating through all the nodes in all the parititons that are traversed in
                        # the dfs path and checking if they are in the merged_nodes list.
                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # If the partition id has already been visited then we know that it doesn't
                            # form a cycle with the current partitions being merged.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # check if merge would create cyclic dependency.
            all_user_nodes = []
            for node in merged_nodes:
                for user_node in node.users:
                    if user_node not in merged_nodes:
                        all_user_nodes.append(user_node)

            if dfs_iter_find_cycle(all_user_nodes):
                # return false indicating cyclic dependency found and
                # merge is aborted
                return False

            # no cyclic dependency found, move forward with the merge
            # updating partition nodes
            partitions_by_id[self_id].nodes = merged_nodes
            # updating assignment map
            for node in partitions_by_id[other_id].nodes:
                assignment[node] = self_id
            # delete other partition
            del partitions_by_id[other_id]

            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
            del partition_map[other_id]

            return True

        def merge_single_node(node: Node, id: Optional[int]):
            # Assign `node` to partition `id` (creating the partition if needed),
            # or unassign it entirely when `id` is None.
            def _update_partition_map(node: Node, id: int):
                # Iterate through all the downstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of this node to the target
                # partition id.
                downstream_nodes = self.dependency_viewer.downstreams_of(node)
                for curr_node in downstream_nodes:
                    target_id = assignment.get(curr_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)

                # Iterate through all the upstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of the upstream node to the
                # current node's partition id.
                upstream_nodes = self.dependency_viewer.upstreams_of(node)
                for curr_node in upstream_nodes:
                    source_id = assignment.get(curr_node, None)
                    if source_id is not None:
                        partition_map[source_id].add(id)

            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)
                _update_partition_map(node, id)

        logger.debug("Proposing partitions...")

        for node in reversed(self.graph_module.graph.nodes):
            # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
            merge_candidates: Dict[int, None] = {}

            # Note a limited horizontal fusion is enabled:
            #   when `node` is not supported, the code below attempts to fuse consumer of `node`.
            #
            # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
            # the fusion by adding an `else` block here to skip horizontal fusion.
            if self.__is_node_supported(node) and node not in assignment:
                partition_id = next(new_partition_id)
                merge_single_node(node, partition_id)
                merge_candidates[partition_id] = None

            # merge all possible partitions
            for node in assignment:
                merge_candidates[assignment[node]] = None

            merge_candidates_list = list(merge_candidates.keys())
            if len(merge_candidates_list) > 1:
                self_id = merge_candidates_list[0]
                for other_id in merge_candidates_list[1:]:
                    # note: merge partition `other_id` into partition `self_id` if
                    # it doesn't create cyclic dependency in the graph, otherwise,
                    # this is a no-op
                    maybe_merge_partition(self_id, other_id)

        # post processing to re-assign "getitem" nodes into upstream partition
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                   _get_qualified_name(user.target) != "_operator.getitem":     # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs, re-assign all following getitem node into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)     # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != id:    # type: ignore[arg-type]
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)

        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])

        return list(partitions_by_id.values())

    def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
        """Rewrite the graph module, collapsing each partition into a submodule."""
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
        return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])

    # remove non-compute-ops that sits at the boundary of a partition.
    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
        """Shrink each partition by dropping non-compute ops that are transparent
        from the partition's input or output boundary (mutates the partitions)."""
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node):
            return node.op == "call_function" and \
                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]

        # cache transparent nodes
        transparent_input_nodes: Dict[Node, bool] = {}
        transparent_output_nodes: Dict[Node, bool] = {}

        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            # Transparent towards inputs: every path to the partition boundary
            # consists only of non-compute nodes.
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            # Symmetric check towards the partition's outputs.
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(output_n, partition, removed_nodes):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False

        for partition in partitions:
            # Note it's ok to use `set` here, since we are only query if a node
            # has been removed. We are NEVER going to iterate on nodes inside
            # the set.
            remove_node: Set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and \
                    (is_transparent_input_node(node, partition.nodes, remove_node) or
                     is_transparent_output_node(node, partition.nodes, remove_node)):
                    remove_node.add(node)

            if len(remove_node) != 0:
                partition.nodes = partition.nodes - remove_node

    def partition_and_fuse(self) -> GraphModule:
        """Convenience wrapper: propose partitions, then fuse them into the graph."""
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions)
        return fused_gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import typing as t
3
+
4
+ import torch
5
+ import torch.fx
6
+ from torch.fx._compatibility import compatibility
7
+ from .shape_prop import TensorMetadata
8
+ from .tools_common import get_node_target, CALLABLE_NODE_OPS
9
+
10
+
11
+ __all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
12
+
13
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`
# None means "any dtype accepted"; otherwise a (positional, keyword) pair:
# per-positional-arg allowed-dtype sequences, and a kwarg-name -> allowed-dtype
# sequence mapping.
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Sequence[torch.dtype]],
        t.Dict[str, t.Sequence[torch.dtype]],
    ]
]

# Mapping from a target typename to the argument dtypes it supports.
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
25
+
26
+
27
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
    """Interface for determining if a fx.Node is supported by a backend"""
    @abc.abstractmethod
    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # Subclasses decide support given the module table
        # (as produced by model.named_modules()) and the candidate node.
        raise NotImplementedError()
35
+
36
+
37
@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
    """
    `_support_dict` maps node.target typename to supported inputs dtypes.

    node.target typename is retrieved using helper function `get_node_target()`

    If supported inputs dtypes is None, it means any dtype is supported, else
    we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).

    The first tuple ([dtypes], ...) indicates what dtypes are supported for
    inputs in node.args and the second dict {"name": [dtypes], ...} indicates
    what dtypes are supported for inputs in node.kwargs.

    For inputs in args, if we don't want to check it, we can put None there,
    e.g. (None, [torch.float]) indicates that we don't care about the type of
    the first input in args. And for inputs in kwargs, if not listed, will not
    be checked.
    """

    _support_dict: SupportDict

    def __init__(
        self,
        support_dict: t.Optional[SupportDict] = None
    ):
        self._support_dict = support_dict or {}

    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """
        Args:
            `submodules`: mapping from module name to the module. This can be
                retrieved by calling model.named_modules().

            `node`: a Fx node that we want to determine whether it's supported.

        Returns:
            `is_supported`: whether the arg `node` is supported.
        """
        # Non-callable ops (placeholder / get_attr / output) are never rejected.
        if node.op not in CALLABLE_NODE_OPS:
            return True

        target = get_node_target(submodules, node)

        # Target not found in _support_dict meaning that we don't support this op at all
        if target not in self._support_dict:
            return False

        # The rule for target is None meaning that we accept any dtype
        if self._support_dict[target] is None:
            return True

        args_dtypes, kwargs_dtypes = self._support_dict[target]  # type: ignore[misc]

        # Check args dtypes
        for i, dtypes in enumerate(args_dtypes):
            # Fewer actual args than declared rules: remaining rules don't apply.
            if len(node.args) <= i:
                break

            # None indicates we don't care about the dtype of args[i]
            if dtypes is None:
                continue

            # If arg is not a node then we don't check it
            if not isinstance(node.args[i], torch.fx.Node):
                continue

            arg_dtype = _get_arg_dtype(node.args[i])  # type: ignore[arg-type]
            if arg_dtype not in dtypes:
                return False

        # Check kwargs dtypes
        for k, dtypes in kwargs_dtypes.items():
            # Kwargs not present on the node are simply skipped.
            if k not in node.kwargs:
                continue

            # If arg is not a node then we don't check it
            if not isinstance(node.kwargs[k], torch.fx.Node):
                continue

            kwarg_dtype = _get_arg_dtype(node.kwargs[k])  # type: ignore[arg-type]
            if kwarg_dtype not in dtypes:
                return False

        return True
124
+
125
+
126
+ # ======================================================================
127
+ # Functional interfaces and utils for defining basic operator support logic
128
+ # and composing them into more complex ones
129
+ # ======================================================================
130
+
131
+ IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
132
+
133
+
134
@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
    """Wrap a bare `IsNodeSupported` callable into an `OperatorSupportBase` instance.

    The callable has the same signature as `OperatorSupportBase.is_node_supported`
    and is invoked verbatim for every query.
    """
    class _WrappedOperatorSupport(OperatorSupportBase):
        # Delegate directly to the wrapped callable.
        def is_node_supported(
            self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
        ) -> bool:
            return is_node_supported(submodules, node)
    return _WrappedOperatorSupport()
147
+
148
+
149
@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Conjunction combinator: the resulting `OperatorSupportBase` reports a node
    as supported only when every given instance reports it as supported
    (vacuously True for zero instances).
    """
    def _conjunction(submods, node) -> bool:
        for support in op_support:
            if not support.is_node_supported(submods, node):
                return False
        return True
    return create_op_support(_conjunction)
161
+
162
+
163
@compatibility(is_backward_compatible=False)
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Disjunction combinator: the resulting `OperatorSupportBase` reports a node
    as supported when at least one given instance reports it as supported
    (vacuously False for zero instances).
    """
    def _disjunction(submods, node) -> bool:
        for support in op_support:
            if support.is_node_supported(submods, node):
                return True
        return False
    return create_op_support(_disjunction)
175
+
176
+
177
@compatibility(is_backward_compatible=False)
class OpSupports:
    """A set of atomic `OperatorSupportBase` instances that can be combined together
    to form more complex operator support logic.
    """
    @classmethod
    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
        """Report a node as non-supported, if any of its arguments is of dtype"""

        def _decline_if_input_dtype(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            # Supported iff no input node carries the rejected dtype.
            return all(_get_arg_dtype(arg) != dtype for arg in node.all_input_nodes)
        return create_op_support(_decline_if_input_dtype)

    @classmethod
    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
        """
        If a node has a name that is in the disallow set, reported it as non-supported.
        """
        def _decline_if_node_in_names(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            return node.name not in disallow_set
        return create_op_support(_decline_if_node_in_names)
211
+
212
+
213
+ def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
214
+ assert isinstance(arg, torch.fx.Node)
215
+ tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
216
+ dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
217
+ return dtype
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx.graph_module import GraphModule
2
+ from typing import Any, Callable, Dict, List, Tuple, Type
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from torch.fx._compatibility import compatibility
7
+
8
+ __all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
9
+
10
+ # Matching method matches the attribute name of current version to the attribute name of `target_version`
11
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
    """Default matching method: the attribute name is identical across versions,
    so the current name is returned unchanged regardless of `target_version`."""
    return name
16
+
17
# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
# The third tuple element maps a bookkeeping attribute name to the module's
# actual attribute name for a given version (see `default_matching`).
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
    torch.nn.modules.conv.Conv2d: (
        1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
    ),
    torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
    torch.nn.modules.pooling.MaxPool2d: (
        1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
    ),
    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}
32
+
33
@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
    """Collect the lowering-relevant attributes of `mod` as listed in `module_fetch_book`.

    Raises RuntimeError when `mod`'s class is missing from the book, or when the
    book entry was written for an older module version than `mod._version`.
    Always includes a "name" entry with the module's type name.
    """
    attrs_for_lowering: Dict[str, Any] = {"name": torch.typename(mod)}

    entry = module_fetch_book.get(type(mod))
    if entry is None:
        raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
                           "please add it to the module_fetch_book, open an issue and @842974287 "
                           "or report a bug to AIACC team directly.")

    version, param_to_fetch, matching_method = entry
    if version < mod._version:
        # The book predates this module version: attribute names may not match.
        raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
                           "please upgrade the module_fetch_book, open an issue and @842974287 "
                           "or report a bug to AIACC team directly.")

    for attr in param_to_fetch:
        attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
    return attrs_for_lowering
54
+
55
@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
    """Walk every node of `fx_module`; for each call_module node targeting a leaf
    module, attach its lowering attributes (see `extract_attrs_for_lowering`) to
    the node, recursing into nested GraphModules instead."""
    submodules = dict(fx_module.named_modules())

    for node in fx_module.graph.nodes:
        if node.op != "call_module":
            continue
        target_mod = submodules[node.target]
        if isinstance(target_mod, GraphModule):
            # Nested GraphModule: recurse rather than fetching attrs directly.
            lift_lowering_attrs_to_nodes(target_mod)
        else:
            node.attrs_for_lowering = extract_attrs_for_lowering(target_mod)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ import torch
4
+ import torch.fx
5
+ import traceback
6
+
7
+ from torch._dispatch.python import enable_python_dispatcher
8
+ from torch.fx.node import Node, map_aggregate
9
+ from typing import Any, Tuple, NamedTuple, Optional, Dict
10
+ from torch.fx._compatibility import compatibility
11
+ from torch._guards import detect_fake_mode
12
+
13
+ __all__ = ['TensorMetadata', 'ShapeProp']
14
+
15
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    """Pertinent information about a tensor within a PyTorch program.

    Produced by `_extract_tensor_metadata` and stored by `ShapeProp` on
    `node.meta['tensor_meta']`.
    """

    # General Tensor metadata
    shape : torch.Size
    dtype : torch.dtype
    requires_grad : bool
    stride : Tuple[int, ...]
    # None when contiguity probing is disabled or no known dense layout matched.
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool
    # Populated only when is_quantized is True (qscheme, scale, zero_point, ...).
    qparams: Dict[str, Any]
30
+
31
+ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
32
+ """
33
+ Extract a TensorMetadata NamedTuple describing `result`.
34
+ """
35
+ shape = result.shape
36
+ dtype = result.dtype
37
+ requires_grad = result.requires_grad
38
+ stride = result.stride()
39
+
40
+ memory_format = None
41
+
42
+ if include_contiguity:
43
+ memory_formats = {
44
+ torch.contiguous_format,
45
+ torch.channels_last,
46
+ torch.channels_last_3d,
47
+ }
48
+ for query_format in memory_formats:
49
+ if result.is_contiguous(memory_format=query_format):
50
+ memory_format = query_format
51
+ break
52
+
53
+ is_quantized = result.is_quantized
54
+ qparams: Dict[str, Any] = {}
55
+ if is_quantized:
56
+ qscheme = result.qscheme()
57
+ qparams["qscheme"] = qscheme
58
+ if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
59
+ qparams["scale"] = result.q_scale() # type: ignore[assignment]
60
+ qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
61
+ elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
62
+ # In this branch, scale and zero_point are expected to be tensors,
63
+ # we store the values as immutable_list in TensorMetadata for
64
+ # easier serialization downstream
65
+ qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment]
66
+ qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment]
67
+ qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
68
+
69
+ return TensorMetadata(
70
+ shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
71
+
72
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
         In this example, we record the shape
         and data type of a module given
         an example input ``torch.randn(50, D_in)``.
         We print the name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)
            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred
        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

    Args:
         module (GraphModule): The module to be executed
         fake_mode (FakeTensorMode): A fake mode for copying the gm

    """
    def __init__(self, gm, fake_mode=None):
        super().__init__(gm)
        # When no fake mode is passed, try to pick up an ambient one so that
        # fake example inputs can still be executed.
        if fake_mode is None:
            fake_mode = detect_fake_mode()
        if fake_mode is not None:
            from torch._dynamo.utils import deepcopy_to_fake_tensor
            # Note:
            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
            # - because we need to write to the tensor_meta of the real module. So we fakify to
            # produce a result (L131 below), to extract tensor meta, and then keep going.
            #
            # If we were to fakify, we would write to the wrong node, and then downstream fusion
            # would be missing the tensor_meta.
            #
            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
            self.fake_mode = fake_mode
        else:
            self.fake_module = None
            self.fake_mode = None

        # Keep a handle on the real module so run_node can restore it after
        # temporarily swapping in the fake copy.
        self.real_module = self.module

    def run_node(self, n : Node) -> Any:
        """Execute `n` (under fake mode when available) and record
        `tensor_meta` (if any tensors were produced) and `type` on `n.meta`.
        """
        try:
            if self.fake_module is not None:
                # Hacky swap. Alternatively, we could do this with overriding
                # call_module and get_attr.
                self.module = self.fake_module
            try:
                if self.fake_mode is not None:
                    with self.fake_mode, enable_python_dispatcher():
                        result = super().run_node(n)
                else:
                    result = super().run_node(n)
            finally:
                # Always restore the real module, even if execution raised.
                self.module = self.real_module
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            # Record metadata for every tensor in the (possibly nested) result;
            # non-tensor values pass through unchanged.
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        if self.fake_mode is not None:
            # Convert real sample inputs to fake tensors so execution stays
            # inside the fake mode.
            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
        else:
            fake_args = args
        return super().run(*fake_args)
+ return super().run(*fake_args)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict, List, Optional, Tuple, Type, Union
4
+
5
+ import torch.fx
6
+ from torch.fx._compatibility import compatibility
7
+ from torch.fx.graph import map_arg
8
+ from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
9
+
10
+ from .tools_common import NodeList
11
+
12
+ __all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
13
+
14
+
15
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """Follow a dotted attribute path on `obj`, returning None if any hop is missing."""
    current = obj
    for part in name.split("."):
        current = getattr(current, part, None)
        if current is None:
            return None
    return current
23
+
24
+
25
@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """Assign `value` to a possibly dotted attribute path `attr` on `obj`."""
    head, dot, rest = attr.partition(".")
    if not dot:
        # Base case: plain attribute name, set it directly.
        setattr(obj, head, value)
    else:
        # Recurse into the first path segment with the remainder of the path.
        setattr_recursive(getattr(obj, head), rest, value)
32
+
33
+
34
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.
    """

    graph: torch.fx.Graph
    # Position of this component in the split order (see split_by_tags).
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Store the nodes in original graph that are placeholder in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Store the nodes in original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    # GraphModule built for this component once the split is materialized.
    gm: Optional[torch.fx.GraphModule] = None
58
+
59
+
60
@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, we have tags = ["a", "b", "c"], the function will create
    the initial submodules in the order of "a", "b", "c".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed in their
    own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
    and output nodes are created when needed while get_attr nodes get copied to submodules
    where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:

    ro:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1

    Returns:
        split_gm: torch fx graph after split
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr.
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in main graph.
    # (dict used as an ordered set: values are always None.)
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag, we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules use them afterwards.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes which are nodes with op of call_module,
        # call_function or call_method. Every callable nodes should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upperstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` has higher order then its upstream components.
        assert comp.order >= mx

        # Map a input of `node` to nodes in the component's graph.
        def remap_func(x):
            # If input is a get_attr node, copy it to current component's graph.
            # Returns the get_attr node in current component's graph.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If input is not a placeholder, it should have been put into a component
            # already. If it's the current component then we return the corresponding
            # node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If input is a placeholder or it's in other components, we want to make it
            # as a placeholder in current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need components mapping for nodes of type "get_attr"
            # that are consumed by the output. Only need to make sure we create
            # corresponding counterparts in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in main graph then we mark it as an output in the component
    # it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a graphmodule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of FX output node. If there's a single
            # output then the output node args is like (output_single), else
            # if there're multiple outputs then the output node args is like
            # ((output_0, output_1, ...)).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    # Preserve the original graph's codegen so the split module prints/compiles
    # the same way as the input module.
    main_g._codegen = gm.graph._codegen

    # If the output nodes consumes get_attr directly in the original graph,
    # then we need to make sure get_attr is copied to the new graph.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (350 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc ADDED
Binary file (6.99 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+
3
+ from torch.fx._compatibility import compatibility
4
+ from torch.fx.graph import Graph
5
+
6
+ from torch.fx.graph_module import GraphModule
7
+ from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
8
+ from torch.nn import Module
9
+
10
+
11
+ __all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
12
+
13
+
14
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    HolderModule is used to copy all the attributes from original module to submodules
    that uses the attributes
    """

    def __init__(self, d):
        super().__init__()
        # Register every entry of `d` as a child submodule under its key.
        for child_name, child_mod in d.items():
            self.add_module(child_name, child_mod)
25
+
26
+
27
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        comp_name (str): name for the new component

        class_name (str): name for the submodule

    Returns:
        A pair of (new GraphModule, mapping from original fqn to the fqn
        prefixed with `comp_name` for every call_module/get_attr target).
    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g. if in the original module there's a get_attr node fetches "conv.weight".
    # We create a HolderModule as root -> add a HolderModule named "conv" ->
    # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
    # the original module.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        target_name_parts = target.split(".")
        curr = submodule
        orig_gm = gm

        # Walk the dotted path in lockstep on the new holder tree and the
        # original module, creating intermediate HolderModules as needed.
        for name in target_name_parts[:-1]:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))

            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node_name = target_name_parts[-1]
        leaf_node = getattr(orig_gm, leaf_node_name)

        orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
        # Relies on custom __setattr__ magic.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
81
+
82
+
83
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Return True if two graphs are identical, i.e they
    - have the same number of outputs in the same order
    - have the same number of inputs in the same order
    - have the same set of nodes, and identical connectivity
    """

    # Matching both output and placeholders forces the whole of `left` to be
    # found inside `right`, which is what "identical" means here.
    matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
    matches = matcher.match(right)

    return len(matches) > 0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from collections import defaultdict
3
+ import copy
4
+ import torch
5
+ from torch.fx import (
6
+ Node,
7
+ Graph,
8
+ )
9
+ from torch.fx._compatibility import compatibility
10
+ from typing import Dict, List, Set, Any, Union, Tuple
11
+ import logging
12
+ import os
13
+
14
+ __all__ = ['SubgraphMatcher', 'InternalMatch']
15
+
16
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
    """Build the module logger, honoring the PYTORCH_MATCHER_LOGLEVEL env var."""
    logger = logging.getLogger(__name__)

    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter("%(filename)s > %(message)s")
    console.setFormatter(formatter)
    console.setLevel(level)
    # add the handlers to the logger
    logger.addHandler(console)
    # Don't propagate to ancestor loggers; the dedicated handler above is the
    # only place matcher logs should be emitted.
    logger.propagate = False
    return logger

logger = _init_logger()
32
+
33
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch:
    """One match of a pattern subgraph inside a target graph."""

    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)

    # nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)

    # nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

    # map from a string name to a node in the target graph
    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
    name_node_map: Dict[str, Node] = field(default_factory=dict)

    def __copy__(self):
        # Shallow-copy every mutable container so mutating the copy does not
        # affect the original. Previously `name_node_map` was omitted here,
        # silently dropping it from copies made during match backtracking.
        return InternalMatch(
            anchors=self.anchors,
            nodes_map=self.nodes_map.copy(),
            placeholder_nodes=self.placeholder_nodes.copy(),
            returning_nodes=self.returning_nodes.copy(),
            name_node_map=self.name_node_map.copy(),
        )
55
+
56
+ @compatibility(is_backward_compatible=False)
57
+ class SubgraphMatcher:
58
def __init__(self, pattern: Graph,
             match_output: bool = False,
             match_placeholder: bool = False,
             remove_overlapping_matches: bool = True,
             ignore_literals: bool = False) -> None:
    """
    Args:
        pattern: the targeted matching pattern, represented in fx.Graph.
        match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
            If False, output node is ignored during match.
        match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
            the targeted pattern. If False, placeholder nodes will be used a wildcard.
        remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
            will be returned.
        ignore_literals: If True, will not check if literals are equal and
            will instead treat them as wildcards.
    """

    self.pattern = pattern
    self.match_output = match_output
    self.match_placeholder = match_placeholder
    self.remove_overlapping_matches = remove_overlapping_matches
    self.ignore_literals = ignore_literals

    if len(pattern.nodes) == 0:
        raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")

    # Dead code in the pattern would never be matched against the target, so
    # reject it up front.
    for node in pattern.nodes:
        if node.op != "output":
            assert len(node.users) > 0, \
                   "SubgraphMatcher cannot be initialized with an pattern with dead code"

    # TODO: assert pattern is a connected graph

    self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
    # FX graphs keep the output node last, so the reversed iterator yields it first.
    output_node = next(iter(reversed(pattern.nodes)))
    # nodes returned by outputs
    self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes

    self.pattern_anchors: List[Node] = []
    if match_output:
        self.pattern_anchors = [output_node]
    else:
        # If a node has output_node as the ONLY user, then this node is a graph sink,
        # and should be matched against as an anchor
        self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]
104
+
105
def _match_attributes(self, pn: Node, gn: Node) -> bool:
    """Return True if two get_attr nodes refer to compatible attribute values.

    Only constant-tensor attributes are supported: two tensor-valued
    attributes match regardless of their contents; any other attribute type
    raises RuntimeError.
    """
    # Attributes matching is complicated. Right now we only support matching constant tensor
    assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
    assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."

    # TODO(tmanlaibaatar) should probably make this actual API
    def _getattr(model: torch.fx.GraphModule, attr_name: str):
        # Resolve a dotted attribute path ("sub.mod.weight") on `model`.
        *prefix, field = attr_name.split(".")
        t = model
        for item in prefix:
            t = getattr(t, item, None)  # type: ignore[assignment]
            assert t is not None

        return getattr(t, field)

    pn_value = _getattr(pn.graph.owning_module, pn.target)
    gn_value = _getattr(gn.graph.owning_module, gn.target)

    if type(pn_value) != type(gn_value):
        return False

    # Don't require exact match on tensor values.
    if isinstance(pn_value, torch.Tensor):
        return isinstance(gn_value, torch.Tensor)
    else:
        raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
    # NOTE(review): unreachable — both branches above return or raise.
    return False
132
+
133
def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
    """Return True if pattern node `pn` can stand for graph node `gn`."""
    # if exact match for placeholder is not required, then use placeholder as a wildcard
    if not self.match_placeholder and pn.op == "placeholder":
        return True

    if pn.op == gn.op:
        if pn.op == "placeholder" or pn.op == "output":
            return True
        elif pn.op == "get_attr":
            return self._match_attributes(pn, gn)
        # call_function / call_module / call_method: compare targets directly.
        return pn.target == gn.target
    return False
145
+
146
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
    """Return True if no intermediate value of the candidate match escapes
    the matched subgraph (only returning nodes may have outside users)."""
    # `lookup` represents all the nodes in `original_graph`
    # that are part of `pattern`

    # Placeholders can be used by other nodes in the graphs
    lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}

    for gn, pn in lookup.items():
        # nodes returned by output are allowed to be used in other areas of the graph
        if pn in self.pattern_returning_nodes:
            continue

        for user in gn.users:
            # If this node has users that were not in `lookup`, then it must leak out of the
            # pattern subgraph
            if user not in lookup:
                return False
    return True
164
+
165
def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
    """Greedily keep matches in order, discarding any match that reuses a
    graph node already claimed by an earlier kept match."""
    non_overlapping_matches: List[InternalMatch] = list()
    nodes_matched: Set[Node] = set()

    for match in matches:
        found_overlap = False
        for pn, gn in match.nodes_map.items():
            # Placeholders/outputs are boundary nodes and may legitimately be
            # shared between matches, so they don't count as overlap.
            if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
                found_overlap = True
                break

        if not found_overlap:
            non_overlapping_matches.append(match)
            for pn, gn in match.nodes_map.items():
                if pn.op not in {"placeholder", "output"}:
                    nodes_matched.add(gn)
    return non_overlapping_matches
182
+
183
def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
    """Match a pattern argument against a graph argument when at least one
    of them is a non-Node literal. A pattern placeholder acts as a wildcard
    that binds (consistently) to the literal; otherwise literals must be of
    the same type and compare equal."""
    assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"

    if isinstance(pn, Node) and not isinstance(gn, Node):
        if pn.op == "placeholder":
            # Check if we've already matched these nodes in the current
            # traversal
            if pn in match.nodes_map:
                return match.nodes_map[pn] == gn

            match.nodes_map[pn] = gn
            return True
        else:
            return False
    elif not isinstance(pn, Node) and isinstance(gn, Node):
        return False
    else:
        return type(gn) == type(pn) and gn == pn
201
+
202
    def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
        """Recursively decide whether pattern node `pn` matches target node `gn`.

        Pairings are recorded in `match.nodes_map`; matching proceeds backwards
        (upstream) through the nodes' flattened args/kwargs. A pattern
        `placeholder` acts as a wildcard and ends the recursion.

        NOTE(review): on failure, ``match = copy.copy(saved_match)`` below rebinds
        only the *local* name, so the caller's `match` object is not actually
        reverted here; the anchor-level backtracking loop in `match()` is what
        restores state between candidates — confirm this is intended.
        """
        logger.info(" matching %s to %s", pn, gn)

        assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}")

        # Check if we've already matched these nodes in the current
        # traversal
        if pn in match.nodes_map:
            return match.nodes_map[pn] == gn

        # A target node may back at most one pattern node.
        # TODO: use a more efficient way to check if gn is matched before: two-way dict
        if gn in match.nodes_map.values():
            return False

        if not self._nodes_are_equal(pn, gn):
            return False

        # Optimistically mark `pn` as a match for `gn`, and save a local copy of match
        saved_match = copy.copy(match)
        match.nodes_map[pn] = gn

        # Placeholder is a wildcard and can be matched with any python object
        # (including list/tuple)
        if pn.op == "placeholder":
            return True

        # Recursively traverse upwards to check if `pn` is a true
        # match for `gn`
        match_found = True

        def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
            # Pairwise-compare two flattened argument lists, recursing into
            # Node pairs and nested list/tuple pairs.
            if len(args1) != len(args2):
                return False

            for a1, a2 in zip(args1, args2):
                if isinstance(a1, Node) and isinstance(a2, Node):
                    matched = self._match_nodes(a1, a2, match)
                elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
                    matched = _match_args(a1, a2)
                else:
                    # Literal-vs-literal (or mixed) comparison; `ignore_literals`
                    # lets any literal mismatch pass.
                    matched = self._match_literals(a1, a2, match) or self.ignore_literals

                if not matched:
                    return False

            return True

        # Flatten all args/kwargs into 1 list of args
        pn_args, gn_args = None, None
        if (
            (len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and
            pn.op == "call_function" and
            isinstance(pn.target, torch._ops.OpOverload)
        ):
            # Arg lists differ in shape: normalize both calls against the op's
            # schema so positional/keyword/default arguments line up.
            args_schema = pn.target._schema.arguments

            def get_all_arguments(orig_args, orig_kwargs):
                # Rebuild the full positional argument list in schema order,
                # filling gaps with schema defaults.
                all_args = []
                for i, schema in enumerate(args_schema):
                    if schema.name in orig_kwargs:
                        all_args.append(orig_kwargs[schema.name])
                    elif not schema.kwarg_only and i < len(orig_args):
                        all_args.append(orig_args[i])
                    else:
                        all_args.append(schema.default_value)
                return all_args

            pn_args = get_all_arguments(pn.args, pn.kwargs)
            gn_args = get_all_arguments(gn.args, gn.kwargs)

        elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()):
            # Same shape: simply concatenate positional args and kwarg values.
            pn_args = list(pn.args)
            gn_args = list(gn.args)
            pn_args.extend(list(pn.kwargs.values()))
            gn_args.extend(list(gn.kwargs.values()))
        else:
            match_found = False

        match_found = (
            match_found and
            pn_args is not None and
            gn_args is not None and
            _match_args(pn_args, gn_args)
        )

        if not match_found:
            # revert to saved_match before matching with current node
            match = copy.copy(saved_match)
            return False

        return True
293
+
294
+ def match(self, graph: Graph) -> List[InternalMatch]:
295
+ """
296
+ Returns:
297
+ The matched subgraphs.
298
+ Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder
299
+ and nodes returned by output) can only be consumed by nodes within the matched subgraph.
300
+
301
+ Subgraph pattern matcher is implemented with the backtracking style in the following steps:
302
+
303
+ 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
304
+ are the "sinks" (nodes with no user other than the output node) of the pattern graph.
305
+ One pattern graph could have multiple anchors if it has multiple return values.
306
+
307
+ 2. In the target graph, we identify the potential candidate nodes that can be matched
308
+ with each anchor. These anchor-candidate pairs are the starting points for
309
+ pairwise per-node matching.
310
+
311
+ 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
312
+ pattern and target graphs. For every pattern nodes along traversal path, we compare it
313
+ against the target nodes. In case any comparison failed, the match for this anchor-candidate
314
+ pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
315
+ for more details.
316
+
317
+ 4. In the case of multiple anchors, every anchor will need to find a match using step 3.
318
+ In addition, the matches found between anchors need to have a common intersection node
319
+ in order for the match to be valid. This is implemented with backtracking. See `backtracking`
320
+ for more details.
321
+
322
+ Notice: graph traversal must be done in the reverser order because a tensor can have multiple
323
+ consumers, but can only have a single producer. Only with reverser order, we can we jointly
324
+ traverse the pattern and target graph in a deterministic path.
325
+
326
+ Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However,
327
+ in practice, it's unlikely to blow up.
328
+
329
+ """
330
+ from torch.fx.passes.utils.fuser_utils import validate_partition
331
+
332
+ # find candidate nodes to match with pattern anchors
333
+ match_candidates: Dict[Node, List[Node]] = defaultdict(list)
334
+ for pattern_anchor in self.pattern_anchors:
335
+ for node in graph.nodes:
336
+ if self._nodes_are_equal(pattern_anchor, node):
337
+ match_candidates[pattern_anchor].append(node)
338
+ match_candidates_list = list(match_candidates.items())
339
+
340
+ logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
341
+
342
+ matches: List[InternalMatch] = []
343
+
344
+ def backtracking(anchor_index, match):
345
+ if anchor_index == len(match_candidates_list):
346
+ match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
347
+ match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
348
+ matches.append(match)
349
+
350
+ logger.info("Found a match: %s\n", match)
351
+ return
352
+
353
+ pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
354
+ saved_match = copy.copy(match)
355
+
356
+ for node in candidate_nodes:
357
+ logger.info("Trying to match anchor %s to %s", pattern_anchor, node)
358
+
359
+ match_found = self._match_nodes(pattern_anchor, node, match)
360
+ if match_found:
361
+ # match next anchor
362
+ backtracking(anchor_index + 1, match)
363
+ else:
364
+ logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)
365
+
366
+ # revert to saved_match before matching with current anchor
367
+ match = copy.copy(saved_match)
368
+
369
+ match = InternalMatch(anchors=self.pattern_anchors)
370
+ if match_candidates_list:
371
+ backtracking(0, match)
372
+
373
+ # filter out the matches where the subgraph is not fully_contained
374
+ before = len(matches)
375
+ matches = [match for match in matches if self._is_contained(match.nodes_map)]
376
+ after = len(matches)
377
+ if before != after:
378
+ logger.info("Filtered out %s matches because they are not fully contained", before - after)
379
+
380
+ # filter out the matches that form a cycle if the subgraph is fused
381
+ valid_matches = []
382
+ for match in matches:
383
+ matched_compute_nodes = \
384
+ [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
385
+ if validate_partition(matched_compute_nodes):
386
+ valid_matches.append(match)
387
+ if len(valid_matches) != len(matches):
388
+ logger.info("Filtered out %s matches because \
389
+ matched subgraph would form a cycle if fused", len(matches) - len(valid_matches))
390
+
391
+ if self.remove_overlapping_matches:
392
+ before = len(valid_matches)
393
+ matches = self._remove_overlapping_matches(valid_matches)
394
+ after = len(matches)
395
+ if before != after:
396
+ logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)
397
+
398
+ logger.info("Matches returned: %s", matches)
399
+
400
+ return matches
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc ADDED
Binary file (2.89 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc ADDED
Binary file (3.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .quantize import * # noqa: F403
2
+ from .observer import * # noqa: F403
3
+ from .qconfig import * # noqa: F403
4
+ from .fake_quantize import * # noqa: F403
5
+ from .fuse_modules import fuse_modules
6
+ from .stubs import * # noqa: F403
7
+ from .quant_type import * # noqa: F403
8
+ from .quantize_jit import * # noqa: F403
9
+
10
+ # from .quantize_fx import *
11
+ from .quantization_mappings import * # noqa: F403
12
+ from .fuser_method_mappings import * # noqa: F403
13
+
14
+
15
def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function: iterates over *calib_data* (a
    torch.utils.data.Dataset or a list of ``(input, target)`` pairs) and runs
    *model* forward on each input. Targets are deliberately ignored — only the
    forward passes are needed for calibration.
    """
    for batch, _target in calib_data:
        model(batch)
22
+
23
+
24
# Public API of torch.quantization, re-exported from the star-imports above.
__all__ = [
    "QuantWrapper",
    "QuantStub",
    "DeQuantStub",
    # Top level API for eager mode quantization
    "quantize",
    "quantize_dynamic",
    "quantize_qat",
    "prepare",
    "convert",
    "prepare_qat",
    # Top level API for graph mode quantization on TorchScript
    "quantize_jit",
    "quantize_dynamic_jit",
    "_prepare_ondevice_dynamic_jit",
    "_convert_ondevice_dynamic_jit",
    "_quantize_ondevice_dynamic_jit",
    # Top level API for graph mode quantization on GraphModule(torch.fx)
    # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
    # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
    "QuantType",  # quantization type
    # custom module APIs
    "get_default_static_quant_module_mappings",
    "get_static_quant_module_class",
    "get_default_dynamic_quant_module_mappings",
    "get_default_qat_module_mappings",
    "get_default_qconfig_propagation_list",
    "get_default_compare_output_module_list",
    "get_quantized_operator",
    "get_fuser_method",
    # Sub functions for `prepare` and `swap_module`
    "propagate_qconfig_",
    "add_quant_dequant",
    "swap_module",
    "default_eval_fn",
    # Observers
    "ObserverBase",
    "WeightObserver",
    "HistogramObserver",
    "observer",
    "default_observer",
    "default_weight_observer",
    "default_placeholder_observer",
    "default_per_channel_weight_observer",
    # FakeQuantize (for qat)
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_histogram_fake_quant",
    # QConfig
    "QConfig",
    "default_qconfig",
    "default_dynamic_qconfig",
    "float16_dynamic_qconfig",
    "float_qparams_weight_only_qconfig",
    # QAT utilities
    "default_qat_qconfig",
    "prepare_qat",
    "quantize_qat",
    # module transformations
    "fuse_modules",
]