koichi12 commited on
Commit
055b29c
·
verified ·
1 Parent(s): 76cb23d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py +274 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py +1230 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py +526 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py +275 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py +93 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py +14 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py +29 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py +4 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py +3 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py +11 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py +110 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py +731 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py +257 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py +514 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py +871 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py +273 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py +88 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py +15 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py +5 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py +7 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py +5 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils._pytree as pytree
3
+ from collections import namedtuple
4
+ import functools
5
+
6
+
7
# NOTE [CustomOp autograd kernel indirection]
# We register the returned closure as the autograd kernel for this custom_op.
# It either calls the autograd formula registered by the user, or goes into
# an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula
def autograd_kernel_indirection(custom_op):
    """Return a dispatching kernel: user 'autograd' impl if present, else fallback."""
    fallback = autograd_not_implemented(custom_op)

    def dispatch(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            user_kernel = custom_op._get_impl('autograd').func
            return user_kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        has_backward = custom_op._has_impl('backward')
        has_save = custom_op._has_impl('save_for_backward')
        if has_backward or has_save:
            missing = 'save_for_backward' if has_backward else 'backward'
            found = 'backward' if has_backward else 'save_for_backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return fallback(*args, **kwargs)
    return dispatch
44
+
45
+
46
+ # TODO(#101191): Use the actual C++ autograd not implemented fallback,
47
+ # or change the default autograd fallback to the autograd not implemented fallback.
48
+ def autograd_not_implemented(custom_op):
49
+ def kernel(*args, **kwargs):
50
+ if torch.is_grad_enabled() and pytree.tree_any(
51
+ lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
52
+ ):
53
+ raise RuntimeError("Autograd has not been implemented for operator")
54
+ with torch._C._AutoDispatchBelowAutograd():
55
+ return custom_op(*args, **kwargs)
56
+ return kernel
57
+
58
+
59
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark the outputs flagged as non-differentiable on the autograd ctx.

    Output types are restricted to be:
    - Tensor
    - Tensor[]
    - int, bool, Scalar, float
    See _check_can_register_backward
    """
    if output_differentiability is None:
        return
    outputs = output if isinstance(output, tuple) else (output,)
    assert len(output_differentiability) == len(outputs)
    to_mark = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, outputs)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                to_mark.append(out)
        elif isinstance(out, list):
            # A Tensor[] output is marked element-wise.
            if not differentiable:
                to_mark.extend(out)
        elif differentiable:
            # Only Tensor-typed outputs may be declared differentiable.
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot have be marked as differentiable in "
                f"output_differentiability.")
    if to_mark:
        ctx.mark_non_differentiable(*to_mark)
89
+
90
+
91
def construct_autograd_kernel(
    schema,
    output_differentiability,
    custom_op,
    op_overload,
    save_for_backward_fn,
    backward_fn):
    """Build an 'autograd' kernel for ``custom_op`` from user-supplied
    ``save_for_backward_fn`` and ``backward_fn``.

    Returns ``apply``, a callable that wraps each invocation in a freshly
    generated ``torch.autograd.Function`` subclass whose forward re-dispatches
    to ``op_overload`` below the Autograd key.
    """

    def apply(*args):
        # Flatten once so the generated Function sees only leaf arguments;
        # `spec` is captured by `forward` to rebuild the original structure.
        flat_args, spec = pytree.tree_flatten(args)
        # Filled in by `forward`, read by `backward` and by `apply`'s return.
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            # Re-dispatch below Autograd so this Function is the only
            # autograd node recorded for the op.
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply
146
+
147
+
148
def gen_autograd_function(name, forward, backward):
    """Dynamically create a torch.autograd.Function subclass named ``name``
    with the given forward/backward as its staticmethods."""
    members = {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    }
    return type(name, (torch.autograd.Function,), members)
158
+
159
+
160
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return (and cache) a namedtuple class whose fields are the schema's arg names."""
    field_names = [arg.name for arg in schema.arguments.flat_all]
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(str(schema.name) + "_args", field_names)  # type: ignore[misc]
167
+
168
+
169
def namedtuple_args(schema, args):
    """Pack positional ``args`` into the namedtuple generated for ``schema``."""
    assert isinstance(args, tuple)
    cls = namedtuple_args_cls(schema)
    return cls(*args)
173
+
174
+
175
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Validate the dict returned by a user backward function.

    ``grad_inputs_dict`` must map exactly the Tensor-like argument names of
    ``forward_op``'s schema to ``None``/Tensor (or list thereof, matching
    ``args_info``, the namedtuple of per-arg *types* recorded in forward).
    Raises RuntimeError with a located message on any violation.
    """
    def error(what):
        # Point the user at where their backward was registered.
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    # Exactly the Tensor-like schema args must appear as keys.
    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        # Sequence[Tensor] argument: gradient must be a matching-length
        # list/tuple of None-or-Tensor.
        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info}")
            continue

        # Scalar (single-Tensor) argument: gradient must be None or a Tensor,
        # and only Tensor args may receive a Tensor gradient.
        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
233
+
234
+
235
+ def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
236
+ result = []
237
+ for name, arg_info in args_info._asdict().items():
238
+ if name not in grad_inputs_dict:
239
+ result.append(pytree.tree_map(lambda x: None, arg_info))
240
+ continue
241
+ result.append(grad_inputs_dict[name])
242
+ return tuple(pytree.tree_leaves(result))
243
+
244
+ # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
245
+ # autograd.Function prefers that users use ctx.save_for_backward to
246
+ # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
247
+ # ctx object.
248
+ def save_pytree_for_backward(ctx, stuff):
249
+ flat_stuff, spec = pytree.tree_flatten(stuff)
250
+ num_elts = len(flat_stuff)
251
+ tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
252
+ if isinstance(thing, torch.Tensor)]
253
+ non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
254
+ if not isinstance(thing, torch.Tensor)]
255
+ tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
256
+ non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
257
+
258
+ ctx.spec = spec
259
+ ctx.num_elts = num_elts
260
+ ctx.save_for_backward(*tensors)
261
+ ctx.tensor_idxs = tensor_idxs
262
+ ctx.saved_non_tensors = non_tensors
263
+ ctx.non_tensor_idxs = non_tensor_idxs
264
+
265
+
266
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree stored by save_pytree_for_backward from ctx."""
    flat = [None] * ctx.num_elts
    # Scatter tensors and non-tensors back to their original flat positions.
    for idx, tensor in zip(ctx.tensor_idxs, ctx.saved_tensors):
        flat[idx] = tensor
    for idx, non_tensor in zip(ctx.non_tensor_idxs, ctx.saved_non_tensors):
        flat[idx] = non_tensor
    return pytree.tree_unflatten(flat, ctx.spec)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc ADDED
Binary file (4.64 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py ADDED
@@ -0,0 +1,1230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import wraps
3
+ from typing import Callable, Optional, Union
4
+
5
+ import torch
6
+ import torch._prims as prims
7
+ import torch._prims_common as utils
8
+ import torch._refs as refs
9
+ from torch._decomp import register_decomposition
10
+ from torch._prims_common import (
11
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
12
+ NumberType,
13
+ ShapeType,
14
+ TensorLike,
15
+ TensorLikeType,
16
+ )
17
+ from torch._prims_common.wrappers import (
18
+ elementwise_type_promotion_wrapper,
19
+ elementwise_unary_scalar_wrapper,
20
+ out_wrapper,
21
+ )
22
+ from torch._refs import _make_inplace
23
+
24
+ __all__ = [
25
+ "alpha_dropout",
26
+ "celu",
27
+ "celu_",
28
+ "dropout",
29
+ "elu",
30
+ "elu_",
31
+ "gelu",
32
+ "glu",
33
+ "group_norm",
34
+ "hardshrink",
35
+ "hardtanh",
36
+ "hinge_embedding_loss",
37
+ "huber_loss",
38
+ "l1_loss",
39
+ "layer_norm",
40
+ "leaky_relu",
41
+ "log_softmax",
42
+ "margin_ranking_loss",
43
+ "mish",
44
+ "mish_",
45
+ "mse_loss",
46
+ "nll_loss",
47
+ "pairwise_distance",
48
+ "pdist",
49
+ "poisson_nll_loss",
50
+ "prelu",
51
+ "relu",
52
+ "relu6",
53
+ "selu",
54
+ "selu_",
55
+ "smooth_l1_loss",
56
+ "softmax",
57
+ "softmin",
58
+ "softplus",
59
+ "softshrink",
60
+ "tanhshrink",
61
+ "threshold",
62
+ "threshold_",
63
+ "triplet_margin_loss",
64
+ ]
65
+
66
+ Tensor = torch.Tensor
67
+ aten = torch._ops.ops.aten
68
+ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
69
+
70
+
71
def _dropout_helper(
    self: TensorLikeType,
    val: float,
) -> TensorLikeType:
    """
    Helper function for all dropout-type operators. During training,
    some of the elements of the input tensor are randomly masked.

    Returns the masked tensor of the boolean values.
    """
    noise = refs._uniform_helper(
        self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device
    )
    # True where the element is kept (uniform sample below the threshold).
    return noise < val
89
+
90
+
91
@register_decomposition(aten.alpha_dropout)
def alpha_dropout(
    self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
) -> TensorLikeType:
    """Reference decomposition of torch.nn.functional.alpha_dropout."""
    if inplace:
        raise NotImplementedError
    if not training:
        return self

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
    )
    if p == 1:
        return torch.zeros_like(self)
    if p == 0:
        return self

    keep_mask = _dropout_helper(self, 1 - p)

    # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf)
    # alpha = - SELU.alpha * SELU.scale, here
    # SELU.alpha = 1.6732632423543772848170429916717 and
    # SELU.scale = 1.0507009873554804934193349852946
    alpha = -1.7580993408473766

    a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p))
    # Dropped positions are replaced by alpha*a, then everything is shifted by
    # alpha*a*p to keep the mean/variance fixed.
    b = torch.logical_not(keep_mask)
    b = b * (alpha * a) + alpha * a * p
    return self * (a * keep_mask) + b
126
+
127
+
128
def _inplace_wrapper(fn):
    """
    Given a nn.functional non-linearity, implements its `inplace: bool` argument
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, inplace=False, **kwargs):
        if not inplace:
            return fn(a, *args, inplace=False, **kwargs)
        # inplace is emulated by writing the result into the input via out=.
        torch._check(
            "out" not in kwargs,
            lambda: "Cannot set inplace=True and pass out= at the same time",
        )
        return fn(a, *args, inplace=False, out=a, **kwargs)

    return _fn
146
+
147
+
148
# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(aten.celu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """
    if inplace:
        raise NotImplementedError

    if alpha is None:
        rhs: TensorLikeType = torch.expm1(a)
    else:
        # alpha must be losslessly representable in the tensor's dtype.
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]

    return torch.where(a > 0, a, rhs)
178
+
179
+
180
@_inplace_wrapper
@out_wrapper()
def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.dropout."""
    if inplace:
        raise NotImplementedError
    if not training:
        return a

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
    )
    if p == 1:
        return torch.zeros_like(a)
    if p == 0:
        return a

    keep_mask = _dropout_helper(a, 1 - p)
    # Scale surviving activations at train time so eval is the identity.
    return a * keep_mask * (1 / (1 - p))
206
+
207
+
208
@register_decomposition(aten.elu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def elu(
    a: TensorLikeType,
    alpha: NumberType = 1.0,
    scale: NumberType = 1.0,
    input_scale: NumberType = 1.0,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """
    if inplace:
        raise NotImplementedError

    # nb. This should be factored out into a can_cast aux function
    python_type = utils.dtype_to_type(a.dtype)

    def _check_castable(name, value):
        # Each scalar must be losslessly representable in the tensor's dtype.
        torch._check(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"{name} argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    _check_castable("input_scale", input_scale)
    _check_castable("scale", scale)
    _check_castable("alpha", alpha)

    return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale))
244
+
245
+
246
@register_decomposition(aten.relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu
    """
    if inplace:
        raise NotImplementedError

    # Replace every element <= 0 with 0, keep the rest unchanged.
    nonpositive = torch.le(a, 0)
    return torch.where(nonpositive, 0, a)
262
+
263
+
264
+ def group_norm(
265
+ input: Tensor,
266
+ num_groups: int,
267
+ weight: Optional[Tensor] = None,
268
+ bias: Optional[Tensor] = None,
269
+ eps: float = 1e-5,
270
+ ) -> Tensor:
271
+ """
272
+ Reference implementation of :func:`torch.nn.functional.group_norm`.
273
+ """
274
+ torch._check(
275
+ input.ndim >= 2,
276
+ lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
277
+ )
278
+
279
+ batch_size = input.shape[0]
280
+ num_channels = input.shape[1]
281
+ torch._check(
282
+ num_channels % num_groups == 0,
283
+ lambda: "Expected number of channels in input to be divisible by num_groups, "
284
+ + f"but got input of shape {input.shape} and num_groups = {num_groups}",
285
+ )
286
+
287
+ # input shape is (N, C, *), so we flatten all inner dimensions except (N, C)
288
+ flattened_inner_size = 1
289
+ for dim_length in input.shape[2:]:
290
+ flattened_inner_size *= dim_length
291
+
292
+ return torch.native_group_norm(
293
+ input,
294
+ weight,
295
+ bias,
296
+ batch_size,
297
+ num_channels,
298
+ flattened_inner_size,
299
+ num_groups,
300
+ eps,
301
+ )[0]
302
+
303
+
304
def layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.layer_norm`.
    """
    # native_layer_norm returns (output, mean, rstd); only the output is needed.
    out, _, _ = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
    return out
315
+
316
+
317
@register_decomposition(aten.leaky_relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """
    if inplace:
        raise NotImplementedError

    # negative_slope must be losslessly representable in the tensor's dtype.
    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        raise ValueError(
            f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        )
    scaled = torch.mul(a, negative_slope)
    return torch.where(torch.gt(a, 0), a, scaled)
339
+
340
+
341
@register_decomposition(aten.mish)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish
    """
    if inplace:
        raise NotImplementedError
    # mish(x) = x * tanh(softplus(x))
    gate = torch.tanh(torch.nn.functional.softplus(a))
    return a * gate
356
+
357
+
358
@register_decomposition(aten.selu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """
    if inplace:
        raise NotImplementedError

    # Fixed-point constants from the SELU formulation.
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    negative_branch = alpha * torch.expm1(a)
    return scale * torch.where(a > 0, a, negative_branch)
378
+
379
+
380
+ # Forwarding alias: the functional variant doesn't support the out kwarg
381
+ # CompositeImplicitAutograd - don't register decomp
382
+ def softmax(
383
+ a: TensorLikeType,
384
+ dim: Optional[int] = None,
385
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
386
+ dtype: Optional[torch.dtype] = None,
387
+ ) -> TensorLikeType:
388
+ # The error is for compat with regular PyTorch, which has this behavior
389
+ # deprecated. For PrimTorch, it's fine to drop support for deprecated
390
+ # behavior because it requires explicit opt in. This error is to inform
391
+ # users how to update their calls.
392
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
393
+ return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
394
+
395
+
396
+ # CompositeImplicitAutograd - don't register decomp
397
+ def softmin(
398
+ a: TensorLikeType,
399
+ dim: Optional[int] = None,
400
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
401
+ dtype: Optional[torch.dtype] = None,
402
+ ) -> TensorLikeType:
403
+ # The error is for compat with regular PyTorch, which has this behavior
404
+ # deprecated. For PrimTorch, it's fine to drop support for deprecated
405
+ # behavior because it requires explicit opt in. This error is to inform
406
+ # users how to update their calls.
407
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
408
+ return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
409
+
410
+
411
# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(aten.softplus)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(beta * a)) / beta`` (``beta`` treated as 1 when
    None), reverting to the identity where ``beta * a > threshold`` for
    numerical stability.

    Raises:
        NotImplementedError: if ``inplace=True``.
        ValueError: if ``beta``'s Python type cannot be safely cast to the
            tensor's element type.
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        # Guard against e.g. a float beta silently truncating on an
        # integer tensor.
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]

    else:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    # For large scaled inputs exp() would overflow; softplus(x) ~= x there,
    # so fall back to the raw input past the threshold.
    return torch.where(scaled_input > threshold, a, rhs)
446
+
447
+
448
@aten.hardshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.hardshrink)
@out_wrapper()
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    """Reference implementation of torch.nn.functional.hardshrink."""
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    # Implemented as a single where over |x| so the zero band is handled
    # in one branch.
    return torch.where(torch.abs(a) <= lambd, 0, a)
457
+
458
+
459
@aten.softshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """Reference implementation of torch.nn.functional.softshrink.

    Raises (via torch._check) if ``lambd`` is negative.
    """
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    torch._check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    # We implement this in one torch.where to generate better code in the backward
    # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)
474
+
475
+
476
+ # Losses
477
+ def _reduction_int_to_str(reduction: int) -> str:
478
+ from torch._decomp.decompositions import Reduction
479
+
480
+ if reduction == Reduction.NONE.value:
481
+ return "none"
482
+ elif reduction == Reduction.MEAN.value:
483
+ return "mean"
484
+ elif reduction == Reduction.SUM.value:
485
+ return "sum"
486
+ else:
487
+ raise ValueError(f"{reduction} is not a valid value for reduction")
488
+
489
+
490
+ def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType:
491
+ if reduction == "sum":
492
+ return torch.sum(loss)
493
+ elif reduction == "mean":
494
+ return torch.mean(loss)
495
+ else: # reduction == "none"
496
+ return loss
497
+
498
+
499
+ def _check_reduction_value(reduction: str):
500
+ if reduction not in ("mean", "sum", "none"):
501
+ raise ValueError(f"{reduction} is not a valid value for reduction")
502
+
503
+
504
+ # This helper function maps depreciated arguments, "size_average" and "reduce"
505
+ # to their corresponding "reduction" string argument
506
+ def _get_string_reduction_arg(
507
+ *, size_average: Optional[bool], reduce: Optional[bool]
508
+ ) -> str:
509
+ if size_average is None:
510
+ size_average = True
511
+ if reduce is None:
512
+ reduce = True
513
+ if size_average and reduce:
514
+ ret = "mean"
515
+ elif reduce:
516
+ ret = "sum"
517
+ else:
518
+ ret = "none"
519
+ return ret
520
+
521
+
522
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.l1_loss

    Elementwise ``|input - target|`` with the requested reduction applied.
    The deprecated ``size_average``/``reduce`` flags, when supplied,
    override ``reduction``.
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.abs(input - target)
    return _apply_loss_reduction(loss, reduction)
545
+
546
+
547
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def smooth_l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.smooth_l1_loss

    Quadratic for ``|input - target| < beta``, linear beyond; degenerates
    to L1 loss when ``beta == 0``. The deprecated ``size_average``/``reduce``
    flags, when supplied, override ``reduction``.
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)

    if beta == 0.0:
        # beta == 0 would divide by zero below; the loss is exactly L1.
        return torch.nn.functional.l1_loss(
            input, target, size_average=size_average, reduce=reduce, reduction=reduction
        )
    else:
        loss = torch.abs(input - target)
        loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
        return _apply_loss_reduction(loss, reduction)
577
+
578
+
579
+ # Forwarding alias: the functional variant doesn't support the out kwarg
580
+ # CompositeImplicitAutograd - don't register decomp
581
+ def log_softmax(
582
+ a: TensorLikeType,
583
+ dim: Optional[int] = None,
584
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
585
+ dtype: Optional[torch.dtype] = None,
586
+ ) -> TensorLikeType:
587
+ # The error is for compat with regular PyTorch, which has this behavior
588
+ # deprecated. For PrimTorch, it's fine to drop support for deprecated
589
+ # behavior because it requires explicit opt in. This error is to inform
590
+ # users how to update their calls.
591
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
592
+ return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
593
+
594
+
595
@register_decomposition(aten.margin_ranking_loss)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.margin_ranking_loss.

    Raises:
        RuntimeError: if the three tensors differ in number of dimensions.
    """
    # loss_without_reduction = max(0, −target * (input1 − input2) + margin)
    if input1.ndim != input2.ndim or input1.ndim != target.ndim:
        raise RuntimeError(
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} "
        )
    _check_reduction_value(reduction)
    # clamp_min implements the max(0, .) hinge.
    loss = torch.clamp_min(-target * (input1 - input2) + margin, 0)
    return _apply_loss_reduction(loss, reduction)
612
+
613
+
614
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def mse_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.mse_loss.

    Elementwise ``(input - target) ** 2`` with the requested reduction.
    The deprecated ``size_average``/``reduce`` flags, when supplied,
    override ``reduction``.
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.pow(input - target, 2)
    return _apply_loss_reduction(loss, reduction)
633
+
634
+
635
@register_decomposition(aten.hinge_embedding_loss)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.hinge_embedding_loss."""
    # loss_without_reduction = input if y == 1
    #                        = max(0, margin - input) if y == -1
    _check_reduction_value(reduction)
    margin_clamp = torch.clamp_min(margin - input, 0)
    # Each where() zeroes out the branch that doesn't apply, so the sum
    # below selects exactly one of the two cases per element.
    output_margin = torch.where(target != 1, margin_clamp, 0)
    output_self = torch.where(target != -1, input, 0)
    loss = output_margin + output_self
    return _apply_loss_reduction(loss, reduction)
650
+
651
+
652
def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    """Core negative log-likelihood loss for 1-D, 2-D, or 3-D ``input``.

    ``input`` holds log-probabilities, laid out as (C), (N, C), or
    (N, C, K); ``target`` holds class indices of matching batch shape.
    Entries whose target equals ``ignore_index`` get weight 0 and drop out
    of both the loss and the mean's denominator.
    """
    torch._check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    torch._check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; See Issue #85834
    # Explicit cast for class_check to bool; See Issue #78071
    """
    from torch._subclasses.fake_tensor import FakeTensor
    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    torch._check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
    """

    # Per-element weights: 0 for ignored targets, otherwise the class
    # weight (1 when no weight tensor is supplied).
    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    class_weight = (
        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
        if weight is None
        else weight[flat_target]
    )
    current_weight = torch.where(
        ignore_classes_mask,
        ignore_class_weight,
        class_weight,
    )

    if input.ndim == 1:
        # implicit batch size = 1
        # input (1 batch size, C classes)
        loss = -input[target] * current_weight
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    else:
        # 3D case (N batch size, C classes, K dimensions)
        # input (N batch size, C classes, K)
        batch_size = input.shape[0]
        extent = input.shape[2]
        numel = batch_size * extent
        # Build flat (batch, k) index pairs so one advanced-indexing gather
        # pulls out the predicted log-probability for every target entry.
        indices = torch.arange(numel)
        bdx = indices // extent
        kdx = indices % extent
        loss = -input[bdx, flat_target, kdx] * current_weight
        loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate weighted mean of the loss function: the denominator is
        # the sum of the effective weights, so ignored entries don't dilute
        # the mean.
        return torch.sum(loss) / torch.sum(current_weight)
729
+
730
+
731
@register_decomposition(aten.nll_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.nll_loss

    Dispatches to ``_nll_loss_nd`` after normalizing the deprecated
    ``size_average``/``reduce`` flags, handling empty inputs, and
    flattening >3-D inputs down to the 3-D case.
    """
    torch._check(
        input.ndim > 0,
        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
    )

    # TODO: raise exception instead of converting value
    # msg = "size_average and reduce args are deprecated, please use reduction argument."
    # Convert these options for consistency with the eager mode
    if size_average is not None or reduce is not None:
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # The expected behavior when the target and input have zero elements:
    #   reduction = 'none' --- tensor([])
    #   reduction = 'sum'  --- tensor(0.)
    #   reduction = 'mean' --- tensor(nan)
    # Mean reduction on empty tensors produces NaN. See the discussion in
    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if input.numel() == 0 and target.numel() == 0:
        if reduction == "none":
            return torch.zeros_like(target)
        elif reduction == "sum":
            # NOTE(review): empty_like on an empty target yields an empty
            # tensor; presumably acceptable since there are no elements —
            # confirm against the eager behavior described above.
            return torch.empty_like(target)
        else:
            return torch.full_like(target, float("nan"))

    # The _nll_loss_nd helper function handles the most common cases.
    #   ndim == 1 (Single Example)
    #       => Batch Size: 1, Input: (C), Target: ()
    #   ndim == 2 (k = 1)
    #       => Batch Size: N, Input: (N, C), Target: (N)
    #   ndim == 3 (k > 1)
    #       => Batch Size: N, Input: (N, C, K), Target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to 3-D case.
    # Input (N batch-size, C classes, k-dimensions)
    # Target (N batch-size, k-dimensions)
    torch._check(
        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
        lambda: (
            "Expected input and target to both have ndim > 0 and "
            "target.shape[1:] == input.shape[2:], but got "
            f"target.shape {target.shape} and input.shape {input.shape}"
        ),
    )

    batch_size = input.shape[0]
    num_classes = input.shape[1]
    out_size = [batch_size] + list(target.shape[1:])

    # Collapse all trailing spatial dims into one K dimension.
    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    if reduction != "none":
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
    else:
        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
        # reshape flattened inner-dim to original k-dimensions
        return torch.reshape(result, out_size)
808
+
809
+
810
# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
# https://github.com/pytorch/pytorch/pull/85041
@register_decomposition(aten.huber_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def huber_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    reduction: Union[str, int] = "mean",
    delta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.huber_loss

    Quadratic for residuals below ``delta``, linear above. ``reduction``
    may also be an ATen integer code (see _reduction_int_to_str).

    Raises (via torch._check) if ``delta <= 0``.
    """
    if type(reduction) is int:
        reduction = _reduction_int_to_str(reduction)
    _check_reduction_value(reduction)  # type: ignore[arg-type]
    torch._check(
        delta > 0,
        lambda: "huber_loss does not support non-positive values for delta.",
    )
    z = (input - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return _apply_loss_reduction(loss, reduction)  # type: ignore[arg-type]
839
+
840
+
841
# tanhshrink does not use _make_elementwise_unary_reference because it does not support out
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink

    Computes ``a - tanh(a)``.

    Raises:
        RuntimeError: if ``a`` is not tensor-like.
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    return a - torch.tanh(a)
856
+
857
+
858
@register_decomposition(aten.threshold)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
    a: TensorLikeType,
    threshold: NumberType,
    value: Union[bool, int, float],
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.threshold

    Replaces elements of ``a`` that are <= ``threshold`` with ``value``.

    Raises:
        NotImplementedError: if ``inplace=True``.
    """

    if inplace:
        raise NotImplementedError

    return torch.where(a <= threshold, value, a)
879
+
880
+
881
# CompositeImplicitAutograd - don't register decomp
# No elementwise type promotion - core op doesn't explicitly type promote
def triplet_margin_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.triplet_margin_loss.

    Uses the p-norm pairwise distance and delegates to the distance-based
    helper below.
    """
    # Fold the deprecated size_average/reduce flags into the reduction
    # string, mirroring the eager-mode conversion.
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
    # since it's a pure Python implementation. Use this helper instead.
    def _p_norm_distance(x, y):
        return torch.pairwise_distance(x, y, p, eps)

    return _triplet_margin_with_distance_loss(
        anchor=anchor,
        positive=positive,
        negative=negative,
        distance_function=_p_norm_distance,
        margin=margin,
        swap=swap,
        reduction=reduction,
    )
912
+
913
+
914
# Pure Python impl - don't register decomp and don't add a ref.  Defined as a
# helper here since triplet_margin_loss can be nicely implemented with it.
def _triplet_margin_with_distance_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    *,
    distance_function: Optional[
        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
    ] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> TensorLikeType:
    """Triplet margin loss with a pluggable distance function.

    Computes ``max(0, margin + d(anchor, positive) - d(anchor, negative))``
    and applies ``reduction``. ``distance_function`` defaults to
    ``torch.pairwise_distance``.
    """
    _check_reduction_value(reduction)

    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    torch._check(
        a_dim == p_dim and p_dim == n_dim,
        lambda: (
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        ),
    )

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al.  If True, and if the positive example is closer to the
    # negative example than the anchor is, swaps the positive example and the
    # anchor in the loss computation.
    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    # clamp_min implements the max(0, .) hinge.
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    return _apply_loss_reduction(loss, reduction)
957
+
958
+
959
@register_decomposition(aten.hardtanh)
@_inplace_wrapper
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    # FIX: was type_promoting_args=("a") — a bare string, not a tuple. It
    # only worked by accident (iterating the one-character string yields
    # "a"); use an explicit one-element tuple for correctness and
    # consistency with every other wrapper use in this file.
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def hardtanh(
    a: TensorLikeType,
    min_val: NumberType = -1,
    max_val: NumberType = 1,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hardtanh

    Clamps ``a`` elementwise to ``[min_val, max_val]``.

    Raises:
        NotImplementedError: if ``inplace=True``.
        RuntimeError: for bool inputs, or for negative limits on a uint8
            input.
    """
    if inplace:
        raise NotImplementedError
    if utils.is_boolean_dtype(a.dtype):
        raise RuntimeError("Bool inputs not supported for hardtanh")

    # preserve legacy behavior of boundaries not causing type promotion
    if utils.is_integer_dtype(a.dtype):
        min_val = int(min_val)  # type: ignore[arg-type]
        max_val = int(max_val)  # type: ignore[arg-type]
        if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
            raise RuntimeError(
                "Cannot do hardtanh on an unsigned type with negative limits"
            )
    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]
990
+
991
+
992
@register_decomposition(aten.gelu)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.gelu

    ``approximate="none"`` uses the exact erf formulation;
    ``approximate="tanh"`` uses the tanh approximation.

    Raises:
        RuntimeError: if ``a`` is not tensor-like, or ``approximate`` is
            neither "none" nor "tanh".
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    # Math constants spelled out to full double precision (names match the
    # C <math.h> macros they mirror).
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        # 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        a_cube = a * a * a
        inner = kBeta * (a + kKappa * a_cube)
        return 0.5 * a * (1 + torch.tanh(inner))
    elif approximate == "none":
        # a * Phi(a) where Phi is the standard normal CDF, via erf.
        kAlpha = M_SQRT1_2
        return a * 0.5 * (1 + torch.erf(a * kAlpha))
    else:
        raise RuntimeError("approximate argument must be either none or tanh.")
1021
+
1022
+
1023
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def poisson_nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.poisson_nll_loss

    With ``log_input=True`` the input is interpreted as log-rate
    (loss = exp(input) - target * input); otherwise as the rate itself,
    with ``eps`` guarding log(0). ``full=True`` adds the Stirling
    approximation term for targets > 1.
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    if log_input:
        loss = torch.exp(input) - target * input
    else:
        loss = input - target * torch.log(input + eps)

    if full:
        # Stirling approximation of log(target!).
        stirling_term = (
            target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target)
        )
        # avoid inplace add; the term is dropped for target <= 1, where the
        # approximation does not apply.
        loss = loss + stirling_term.masked_fill(target <= 1, 0)
    return _apply_loss_reduction(loss, reduction)
1059
+
1060
+
1061
@register_decomposition(aten.prelu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "weight"),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu

    Returns ``a`` where ``a > 0`` and ``a * weight`` elsewhere, with
    ``weight`` either a single learnable slope or one slope per channel
    (dim 1 of ``a``).
    """
    torch._check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    torch._check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        # Per-channel weights must match the channel dimension of `a`.
        torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        torch._check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    torch._check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    if a.ndim == 0:
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        # Broadcast the weight along the channel dimension (dim 1, or dim 0
        # for 1-D inputs) so the where() below is elementwise.
        weight = prims.broadcast_in_dim(
            weight, a.shape, tuple() if weight.ndim == 0 else (0 if a.ndim == 1 else 1,)
        )

    return torch.where(a > 0, a, a * weight)
1101
+
1102
+
1103
@register_decomposition(aten.relu6)
@_inplace_wrapper
@out_wrapper()
def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu6

    Clamps ``a`` to ``[0, 6]`` via hardtanh.

    Raises:
        NotImplementedError: if ``inplace=True``.
    """
    if inplace:
        raise NotImplementedError

    # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
    # It may be better to use clamp here, but we use hardtanh to replicate
    # the behavior of the existing implementation
    return torch.nn.functional.hardtanh(a, 0, 6)
1117
+
1118
+
1119
@register_decomposition(aten.glu)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.glu.

    Splits ``a`` in half along ``dim`` into (b, c) and returns
    ``b * sigmoid(c)``. Requires an even-sized ``dim``.
    """
    dim = utils.canonicalize_dims(a.ndim, dim)
    torch._check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)

    return b * torch.sigmoid(c)
1134
+
1135
+
1136
@register_decomposition(aten.pairwise_distance)
@out_wrapper()
def pairwise_distance(
    x1: TensorLikeType,
    x2: TensorLikeType,
    p: NumberType = 2.0,
    eps: NumberType = 1e-6,
    keepdim=False,
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.pairwise_distance.

    p-norm of ``x1 - x2`` over the last dimension; ``eps`` is added to the
    difference (matching eager behavior) to keep the norm's gradient finite
    at zero.
    """
    return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim)
1146
+
1147
+
1148
@register_decomposition(aten.pdist)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    """Reference implementation of torch.pdist: condensed pairwise p-norm
    distances between the rows of a 2-D tensor (upper triangle, flattened)."""
    torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, clamped at 0 to absorb
        # tiny negative values from floating-point cancellation.
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    # Gather the strict upper triangle of the full distance matrix; t is
    # square, so row * t.shape[0] + col is a valid row-major flat index.
    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])
1167
+
1168
+
1169
@register_decomposition(aten.pixel_shuffle)
@out_wrapper()
def pixel_shuffle(self: Tensor, upscale_factor: int):
    """Reference implementation of torch.pixel_shuffle.

    Rearranges (..., C*r^2, H, W) into (..., C, H*r, W*r) by unfolding the
    r*r factor out of the channel dimension and interleaving it into the
    spatial dimensions.
    """
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] // upscale_factor**2
    HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor)
    n = len(batch)
    B_dims = range(n)
    # Positions of the split-out (C, r1, r2, H, W) dims after the view.
    C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            C_out,
            upscale_factor,
            upscale_factor,
            self.shape[-2],
            self.shape[-1],
        )
        # Interleave each r factor after its spatial dim, then collapse.
        .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )
1195
+
1196
+
1197
@register_decomposition(aten.pixel_unshuffle)
@out_wrapper()
def pixel_unshuffle(self: Tensor, downscale_factor: int):
    """Reference implementation of torch.pixel_unshuffle (inverse of
    pixel_shuffle): rearranges (..., C, H*r, W*r) into (..., C*r^2, H, W)."""
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] * downscale_factor**2
    HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor)
    n = len(batch)
    B_dims = range(n)
    # Positions of the split-out (C, H, r1, W, r2) dims after the view.
    C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            self.shape[-3],
            HW_out[0],
            downscale_factor,
            HW_out[1],
            downscale_factor,
        )
        # Gather the two r factors next to the channel dim, then collapse.
        .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )
1223
+
1224
+
1225
# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg)
# In-place variants built from the functional refs above; presumably
# _make_inplace writes the result back into the first argument — see its
# definition for the exact contract.
celu_ = _make_inplace(celu)
elu_ = _make_inplace(elu)
mish_ = _make_inplace(mish)
selu_ = _make_inplace(selu)
threshold_ = _make_inplace(threshold)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc ADDED
Binary file (40.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.ao.nn.quantized as nnq
4
+ import torch.ao.nn.quantized.dynamic as nnqd
5
+ from torch.ao.quantization import prepare
6
+ from typing import Dict, List, Optional, Any, Union, Callable, Set
7
+
8
+ from torch.ao.quantization.quantization_mappings import (
9
+ get_default_compare_output_module_list,
10
+ )
11
+
12
+ NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
13
+ nnqd.Linear,
14
+ nnq.Linear,
15
+ nnqd.LSTM,
16
+ nn.LSTM,
17
+ }
18
+
19
+
20
+ def _find_match(
21
+ str_list: Union[Dict[str, Any], List[str]], key_str: str,
22
+ postfix: str,
23
+ ) -> Optional[str]:
24
+ split_str = key_str.split(".")
25
+ if split_str[-1] == postfix:
26
+ match_string = "".join(key_str.split(".")[0:-1])
27
+ for s2 in str_list:
28
+ pattern1 = "".join(s2.split(".")[0:-1])
29
+ pattern2 = "".join(s2.split(".")[0:-2])
30
+ if match_string == pattern1:
31
+ return s2
32
+ if match_string == pattern2:
33
+ return s2
34
+
35
+ # For matching "fc.weight" and "fc._packed_params._packed_params"
36
+ if postfix == "_packed_params":
37
+ match_string = "".join(key_str.split(".")[0:-2])
38
+ if len(match_string) == 0:
39
+ return None
40
+ for s2 in str_list:
41
+ pattern1 = "".join(s2.split(".")[0:-1])
42
+ pattern2 = "".join(s2.split(".")[0:-2])
43
+ if match_string == pattern1:
44
+ return s2
45
+ if match_string == pattern2:
46
+ return s2
47
+ return None
48
+ else:
49
+ return None
50
+
51
+
52
def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage::

        wt_compare_dict = compare_weights(
            float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(
                key,
                compute_error(
                    wt_compare_dict[key]['float'],
                    wt_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        # Plain case: quantized "x.weight" pairs with float "x.weight".
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            # Packed params store (weight, bias); index 0 is the weight.
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                # NOTE(review): the hh assignments below overwrite the ih
                # entries, so only weight_hh is reported per LSTM layer.
                # Looks unintended — confirm before relying on LSTM results.
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                # Packed CellParams: state[0][4] is [packed_ih, packed_hh];
                # each packed param's state[0][0] is the quantized weight.
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict
119
+
120
+
121
def _get_logger_dict_helper(
    mod: nn.Module, target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    """Recursively collect the stats of every ``Logger`` child of ``mod``
    into ``target_dict``, keyed by dotted module path + ".stats".

    Args:
        mod: module whose logger stats should be collected
        target_dict: destination mapping, mutated in place
        prefix: dotted path of ``mod`` within the root module
    """

    def with_dot(p):
        # "" stays "", anything else gets a trailing separator.
        return p + "." if p else p

    # Record at most one logger's stats at this level.
    for _, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[with_dot(prefix) + "stats"] = child.stats
            break

    # Recurse into all children (including the logger level just handled).
    for name, child in mod.named_children():
        child_prefix = with_dot(prefix) + name if prefix else name
        _get_logger_dict_helper(child, target_dict, child_prefix)
144
+
145
+
146
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse ``mod`` and gather every logger's stats into one dict,
    keyed by dotted module path.  Mainly used for quantization accuracy
    debugging.

    Supported loggers:
        ShadowLogger: logs outputs of a quantized module and its float shadow
        OutputLogger: logs outputs of a module

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        the dictionary holding all collected logger stats
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    collected: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, collected, prefix)
    return collected
167
+
168
+
169
class Logger(nn.Module):
    r"""Base class for stats logging"""

    def __init__(self):
        super().__init__()
        # Mapping of stat name -> recorded values; populated by subclasses.
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # Base logger records nothing; subclasses override.
        return None
185
+
186
+
187
class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super().__init__()
        # Parallel lists: one entry per forward call.
        self.stats["float"] = []
        self.stats["quantized"] = []

    def forward(self, x, y):
        """
        """  # blank docblock to make autodoc happy
        # Multi-output modules (e.g. LSTM) pass tuples; keep only the first
        # output of each.
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        # x comes from the quantized module, y from its float shadow.
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())
206
+
207
+
208
class OutputLogger(Logger):
    r"""Class used to log the outputs of the module
    """

    def __init__(self):
        super().__init__()
        # All observed outputs, in call order.
        self.stats["tensor_val"] = []


    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # Record and pass the value through unchanged (observer contract).
        self.stats["tensor_val"].append(x)
        return x
222
+
223
+
224
+ def _convert_tuple_to_list(t: Any) -> Any:
225
+ return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
226
+
227
+
228
+ def _dequantize_tensor_list(t: Any) -> Any:
229
+ return (
230
+ [_dequantize_tensor_list(x) for x in t]
231
+ if type(t) is list
232
+ else t.dequantize()
233
+ if t.is_quantized
234
+ else t
235
+ )
236
+
237
+
238
class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super().__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        # Run the quantized module on the original (possibly quantized)
        # inputs, the float shadow on dequantized copies, and log both.
        xl = _convert_tuple_to_list(x)
        output = self.orig_module(*xl)
        xl_float = _dequantize_tensor_list(xl)
        shadow_output = self.shadow_module(*xl_float)
        self.logger(output, shadow_output)
        return output

    # Each method below mirrors one functional op of the quantized module
    # (nnq.FloatFunctional-style): run quantized, run float shadow on
    # dequantized inputs, log the pair, return the quantized result.

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.add(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add(x, y)
        self.logger(output, shadow_output)
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.add_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.add_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.mul(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.mul(x, y)
        self.logger(output, shadow_output)
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.mul_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.mul_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.cat(x, dim)
        x = [y.dequantize() for y in x]
        shadow_output = self.shadow_module.cat(x, dim)
        self.logger(output, shadow_output)
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        output = self.orig_module.add_relu(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add_relu(x, y)
        self.logger(output, shadow_output)
        return output
323
+
324
+
325
def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching
    quantized module as the shadow if the float module type is in
    ``module_swap_list``.

    Example usage::

        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger used in the shadow module to process the
            outputs of the quantized module and its float shadow
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_children = dict(float_module.named_children())

    swaps = {}
    for child_name, q_child in q_module.named_children():
        # Only children present in both trees can be paired.
        if child_name not in float_children:
            continue
        float_child = float_children[child_name]

        if type(float_child) not in module_swap_list:
            # Not a swap point here; the swap point may be deeper.
            prepare_model_with_stubs(float_child, q_child, module_swap_list, logger_cls)
        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        elif not _is_identical_module_type(q_child, float_child):
            swaps[child_name] = Shadow(q_child, float_child, logger_cls)

    # Reassign after iteration to avoid mutating named_children mid-loop.
    for child_name, shadow in swaps.items():
        q_module._modules[child_name] = shadow
369
+
370
+ def _is_identical_module_type(mod1, mod2):
371
+ # Compare if two modules have the same dtype
372
+ mod1_module_types = [type(mod) for mod in mod1.modules()]
373
+ mod2_module_types = [type(mod) for mod in mod2.modules()]
374
+ return mod1_module_types == mod2_module_types
375
+
376
+
377
+
378
def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
) -> Dict[str, Dict]:
    r"""Compare each quantized module in ``q_model`` against its float
    counterpart on the same input.

    Internally calls ``prepare_model_with_stubs`` to wrap each selected
    quantized module in a ``Shadow`` that runs the float module on the same
    (dequantized) input, then runs ``q_model`` once and collects the logged
    output pairs.

    Example usage::

        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: float module types at which shadows are attached
        data: input data used to run the prepared q_model
        logger_cls: logger type used inside each shadow (default ShadowLogger)

    Return:
        dict keyed by module name; each entry has 'float' and 'quantized'
        output lists for error computation.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    return get_logger_dict(q_model)
418
+
419
+
420
def get_matching_activations(
    float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict keyed by quantized module name; each entry has keys
        'float' and 'quantized' holding the matched activation lists
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    # Reverse-sorted so longer (more specific) paths are matched first.
    float_keys = sorted(float_dict, reverse=True)
    for q_key, q_stats in quantized_dict.items():
        # Skip loggers that never observed an output.
        if not q_stats["tensor_val"]:
            continue
        match_key = _find_match(float_keys, q_key, "stats")
        if match_key is None:
            continue
        act_dict[q_key] = {
            "float": float_dict[match_key]["tensor_val"],
            "quantized": q_stats["tensor_val"],
        }
    return act_dict
447
+
448
+
449
def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Attach a logger to every module of both models whose type is in
    ``allow_list``, so their outputs can later be compared.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    # Loggers ride the observer mechanism: a qconfig whose activation
    # constructor is the logger class, with no weight observer.
    qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)

    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})

    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    # The quantized model additionally allows observers on certain non-leaf
    # modules (dynamic/quantized Linear, LSTM).
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
        prepare_custom_config_dict={}
    )
479
+
480
+
481
def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations of ``float_model`` and ``q_model`` at
    matching locations for the same input.

    Example usage::

        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(
                key,
                compute_error(
                    act_compare_dict[key]['float'],
                    act_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        dict keyed by quantized module name; each entry has keys 'float' and
        'quantized' holding the matched activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    # Instrument both models, run them on the same data, then pair up the
    # recorded activations by module path.
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    return get_matching_activations(float_model, q_model)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc ADDED
Binary file (42.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.ao.nn.quantized.dynamic as nnqd
5
+ import torch.ao.nn.quantized as nnq
6
+ import torch.ao.nn.intrinsic.qat as nniqat
7
+ import torch.ao.nn.qat as nnqat
8
+ import torch.ao.nn.intrinsic as nni
9
+ import torch.ao.nn.intrinsic.quantized as nniq
10
+ toq = torch.ops.quantized
11
+ from torch.fx import GraphModule
12
+ from torch.fx.graph import Node
13
+
14
+ from .utils import (
15
+ get_target_type_str,
16
+ getattr_from_fqn,
17
+ return_first_non_observer_node,
18
+ )
19
+
20
+ from .ns_types import (
21
+ NSSingleResultValuesType,
22
+ NSSingleResultType,
23
+ )
24
+
25
+ from typing import List, Optional, Dict, Callable
26
+
27
def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    """Return the detached ``weight`` tensor of ``mod``."""
    weight = mod.weight  # type: ignore[operator]
    return weight.detach()
29
+
30
def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    """Return the detached weight of the first child of a fused module."""
    first = mod[0]  # type: ignore[index]
    return first.weight.detach()
32
+
33
def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    """Return the weight half of a quantized module's ``_weight_bias()`` pair."""
    params = mod._weight_bias()  # type: ignore[operator]
    return params[0]
35
+
36
def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    """Collect the detached ih/hh weight matrices of a float ``nn.LSTM``."""
    weights = []
    for idx, name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        # Only the input-hidden / hidden-hidden weight matrices; skip biases.
        if 'weight_ih_l' in name or 'weight_hh_l' in name:
            weights.append(mod._flat_weights[idx].detach())  # type: ignore[index]
    return weights
43
+
44
def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    """Extract ih/hh weight tensors from a dynamically quantized LSTM."""
    weights = []
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        # Packed CellParams: state[0][4] holds [packed_ih, packed_hh];
        # each packed param's own state[0][0] is the quantized weight.
        packed_pair = weight_value.param.__getstate__()[0][4]
        weights.append(packed_pair[0].__getstate__()[0][0])
        weights.append(packed_pair[1].__getstate__()[0][0])
    return weights
50
+
51
def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    """Return the weight tensor of a conv module (plain, fused, or quantized)."""
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    if isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        # Fused conv+relu: the conv is the first child.
        return mod[0].weight.detach()
    # Quantized conv: _weight_bias() returns (weight, bias).
    return mod._weight_bias()[0]  # type: ignore[operator]
62
+
63
def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    """Return the weight tensor of a linear module (plain, fused, or quantized)."""
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    if isinstance(mod, nni.LinearReLU):
        # Fused linear+relu: the linear is the first child.
        return mod[0].weight.detach()
    return mod._weight_bias()[0]  # type: ignore[operator]
70
+
71
def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    """Return ih/hh weight tensors of an LSTM, float or dynamically quantized."""
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        weights = []
        for idx, name in enumerate(mod._flat_weights_names):
            if 'weight_ih_l' in name or 'weight_hh_l' in name:
                weights.append(mod._flat_weights[idx].detach())
        return weights
    assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
    weights = []
    for weight_value in mod._all_weight_values:
        # Packed CellParams: state[0][4] is [packed_ih, packed_hh].
        packed_pair = weight_value.param.__getstate__()[0][4]
        weights.append(packed_pair[0].__getstate__()[0][0])
        weights.append(packed_pair[1].__getstate__()[0][0])
    return weights
87
+
88
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Extract the detached weight feeding a functional conv call."""
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg = node.args[1]
    assert isinstance(weight_arg, Node)
    weight_node = return_first_non_observer_node(weight_arg, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == 'get_attr'
    return getattr_from_fqn(gm, weight_node.target).detach()  # type: ignore[arg-type]
97
+
98
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Extract the weight from a quantized functional conv's prepacked state."""
    # qconv prepacked state is arg 1
    state_node = node.args[1]
    assert isinstance(state_node, Node)
    assert state_node.op == 'get_attr'
    state_obj = getattr_from_fqn(gm, state_node.target)  # type: ignore[arg-type]
    return state_obj.weight()
105
+
106
def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Extract the weight feeding a functional linear call.

    Supported patterns (traversing backwards from the weight arg):
      weight -> obs -> linear
      weight -> to(torch.float16) -> dequantize -> linear
      weight -> linear
    """
    weight_arg = node.args[1]
    assert isinstance(weight_arg, Node)

    if weight_arg.op == 'call_module':
        # weight -> obs -> linear
        producer = weight_arg.args[0]
        assert isinstance(producer, Node)
        assert producer.op == 'get_attr'
        weight = getattr_from_fqn(gm, producer.target)  # type: ignore[arg-type]
        return weight.detach()
    elif weight_arg.op == 'call_method':
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = weight_arg
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        producer = to_fp16_node.args[0]
        assert isinstance(producer, Node)
        assert producer.op == 'get_attr'
        weight = getattr_from_fqn(gm, producer.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        # weight -> linear (plain get_attr)
        assert weight_arg.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_arg.target)  # type: ignore[arg-type]
        return weight.detach()
142
+
143
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Extract the weight from a quantized functional linear's packed params."""
    # packed weight is arg 1
    packed_node = node.args[1]
    assert isinstance(packed_node, Node)
    assert packed_node.op == 'get_attr'
    packed_weight = getattr_from_fqn(gm, packed_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight
152
+
153
def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
    """Return the mapping {fx op kind -> {module type / function -> weight
    extraction fn}} used by extract_weight_from_node below."""

    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
        'call_module': {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        'call_function': {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn
217
+
218
def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> Optional[NSSingleResultType]:
    """Return a Numeric Suite weight-result dict for ``node`` if a weight
    extraction function is registered for its op/type, else ``None``.

    Args:
        node: fx node to inspect ('call_function' or 'call_module')
        gm: graph module owning the node
        op_to_type_to_weight_extraction_fn: optional override of the default
            extraction-function table
    """
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    def package(weight) -> NSSingleResultType:
        # Weight results reference the node itself on both ends.
        return {
            'type': res_type,
            'values': [weight],
            'prev_node_name': node.name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': node.name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
        }

    if node.op == 'call_function':
        function_mapping = op_to_type_to_weight_extraction_fn['call_function']
        for target_fn_type, extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                return package(extraction_fn(node, gm))

    elif node.op == 'call_module':
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn['call_module']
        for target_mod_type, extraction_fn in module_mapping.items():
            # Exact type match (deliberately not isinstance): subclasses may
            # store weights differently.
            if type(mod) == target_mod_type:
                return package(extraction_fn(mod))

    return None
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc ADDED
Binary file (25.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc ADDED
Binary file (30.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc ADDED
Binary file (27.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import operator
8
+ from typing import List
9
+
10
+ import torch
11
+ from torch._higher_order_ops.auto_functionalize import (
12
+ auto_functionalized,
13
+ get_mutable_arg_names,
14
+ )
15
+ from torch.export import ExportedProgram
16
+
17
+
18
+ def unsafe_remove_auto_functionalized_pass(
19
+ ep: ExportedProgram,
20
+ ) -> ExportedProgram:
21
+ """
22
+ This pass removes an instances of the higher order op 'auto_functionalized',
23
+ and modifies the calling EP inplace to have the original mutator op.
24
+ This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
25
+ """
26
+ auto_functionalize_nodes: List[torch.fx.Node] = []
27
+ for module in ep.graph_module.modules():
28
+ if not isinstance(module, torch.fx.GraphModule):
29
+ continue
30
+ for node in ep.graph.nodes:
31
+ if node.op == "call_function" and node.target is auto_functionalized:
32
+ auto_functionalize_nodes.append(node)
33
+
34
+ # Update every use of the HOP
35
+ for node in reversed(auto_functionalize_nodes):
36
+ func = node.args[0]
37
+ original_kwargs = node.kwargs
38
+ assert isinstance(func, torch._ops.OpOverload)
39
+
40
+ with ep.graph.inserting_before(node):
41
+ # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
42
+ new_node = ep.graph.call_function(func, kwargs=node.kwargs)
43
+ for k, v in node.meta.items():
44
+ new_node.meta[k] = v
45
+
46
+ # Replace auto_functionalize(func, args) with just func(args)
47
+ node.replace_all_uses_with(new_node)
48
+
49
+ mutable_args_names = get_mutable_arg_names(new_node.target)
50
+ output_specs = ep.graph_signature.output_specs
51
+
52
+ # update the users of the auto_func node (the getitem nodes)
53
+ for user in list(new_node.users.keys()):
54
+ assert user.target == operator.getitem
55
+ # getitem corresponding to a mutated input, just replace all uses with the original input
56
+ if user.args[1] >= len(func._schema.returns):
57
+ assert user.args[1] <= len(func._schema.returns) + len(
58
+ mutable_args_names
59
+ )
60
+
61
+ # If the result of getitem was used in an output node, update the output spec with the correct name
62
+ adusted_index = user.args[1] - len(func._schema.returns)
63
+ original_arg = original_kwargs[mutable_args_names[adusted_index]]
64
+ for spec in output_specs:
65
+ if spec.arg.name == user.name:
66
+ spec.arg.name = original_arg.name # pyre-ignore
67
+ break
68
+
69
+ # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
70
+ # of the getitem calls following the HOP.
71
+ user.replace_all_uses_with(
72
+ original_kwargs[mutable_args_names[adusted_index]]
73
+ )
74
+
75
+ if len(func._schema.returns) == 1:
76
+ # If the function has 1 return then it will just directly return the
77
+ # result -- we don't need a getitem. So we can replace all the
78
+ # getitem(auto_functionalized, 0) with just the note itself.
79
+ for user in list(new_node.users.keys()):
80
+ if user.args[1] == 0:
81
+ user.replace_all_uses_with(new_node)
82
+
83
+ # Same case as above, update the output spec if getitem result used in an output node
84
+ for spec in output_specs:
85
+ if spec.arg.name == user.name:
86
+ spec.arg.name = new_node.name
87
+ break
88
+
89
+ new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
90
+ ep.graph.erase_node(node)
91
+
92
+ ep.graph.eliminate_dead_code()
93
+ return ep
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ op_add = '+'
2
+ op_sub = '-'
3
+ op_mul = '*'
4
+ op_div = '/'
5
+ op_eq = '='
6
+ op_neq = '!='
7
+ op_imp = '=>'
8
+ op_matching = '⊳'
9
+ op_consistency = '~'
10
+ op_precision = '⊑'
11
+ op_leq = '≤'
12
+ op_lt = '<'
13
+ op_gt = '>'
14
+ op_mod = '%'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import z3 # type: ignore[import]
3
+ HAS_Z3 = True
4
+ # dynamic type
5
+ dyn = z3.DeclareSort('Dyn')
6
+ dyn_type = z3.Const('dyn', dyn)
7
+
8
+ # dimension
9
+ dim = z3.Datatype('dim')
10
+ dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
11
+ dim = dim.create()
12
+
13
+ # tensors
14
+ tensor_type = z3.Datatype('TensorType')
15
+ tensor_type.declare('Dyn', ('dyn', dyn))
16
+ tensor_type.declare('tensor1', ('0', dim))
17
+ tensor_type.declare('tensor2', ('0', dim), ('1', dim))
18
+ tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
19
+ tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
20
+ tensor_type = tensor_type.create()
21
+
22
+ # create dimension
23
+ D = dim.dim
24
+
25
+ z3_dyn = tensor_type.Dyn(dyn_type)
26
+
27
+
28
+ except ImportError:
29
+ HAS_Z3 = False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # mypy: disable-error-code=attr-defined
2
+ from .core import unify, reify # noqa: F403
3
+ from .more import unifiable # noqa: F403
4
+ from .variable import var, isvar, vars, variables, Var # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc ADDED
Binary file (7.12 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .core import dispatch
2
+ from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
3
+ MDNotImplementedError)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc ADDED
Binary file (8.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc ADDED
Binary file (6.37 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import graph_drawer
2
+ from . import graph_manipulation
3
+ from . import net_min_base
4
+ from . import operator_support
5
+ from . import param_fetch
6
+ from . import reinplace
7
+ from . import shape_prop
8
+ from . import split_module
9
+ from . import split_utils
10
+ from . import splitter_base
11
+ from . import tools_common
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (231 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, NamedTuple, Optional
2
+
3
+ import torch
4
+ from torch.fx._compatibility import compatibility
5
+ from torch.fx.graph import Graph
6
+ from torch.fx.graph_module import GraphModule
7
+ from torch.fx.node import (
8
+ map_arg,
9
+ Node,
10
+ Target,
11
+ )
12
+ from torch.fx.passes.shape_prop import ShapeProp
13
+
14
+ __all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
15
+ 'get_size_of_node']
16
+
17
+ @compatibility(is_backward_compatible=False)
18
+ def replace_target_nodes_with(
19
+ fx_module: GraphModule,
20
+ old_op: str,
21
+ old_target: Target,
22
+ new_op: str,
23
+ new_target: Target,
24
+ ):
25
+ """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
26
+ and updates them to match the new op code and target"""
27
+ new_graph = Graph()
28
+ val_map: Dict[Node, Node] = {}
29
+ for node in fx_module.graph.nodes:
30
+ if node.op == old_op and node.target == old_target:
31
+ args = map_arg(node.args, lambda n: val_map[n])
32
+ kwargs = map_arg(node.kwargs, lambda n: val_map[n])
33
+ assert isinstance(args, tuple)
34
+ assert isinstance(kwargs, dict)
35
+ val_map[node] = new_graph.create_node(
36
+ new_op, new_target, args, kwargs, node.name
37
+ )
38
+ else:
39
+ val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
40
+ fx_module.graph = new_graph
41
+
42
+
43
+ @compatibility(is_backward_compatible=False)
44
+ class size_bytes(NamedTuple):
45
+ output_size: int
46
+ total_size: int
47
+
48
+
49
+ @compatibility(is_backward_compatible=False)
50
+ def get_size_of_all_nodes(
51
+ fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
52
+ ) -> None:
53
+ """Given a fx graph module, update each node with its total size (weights + bias + output)
54
+ and its output_size(output). For a non-module node, the total size is the output size.
55
+ return total size"""
56
+ if args is not None:
57
+ # Mark shape and dtype for each node (node.shape and node.dtype)
58
+ ShapeProp(fx_module).propagate(*args)
59
+ # Calculate the total size of the whole fx graph
60
+ total_size_of_graph = 0.0
61
+ for node in fx_module.graph.nodes:
62
+ if node.op == "output":
63
+ break
64
+ node.size_bytes = get_size_of_node(fx_module, node)
65
+ return
66
+
67
+
68
+ @compatibility(is_backward_compatible=False)
69
+ def get_tensor_meta(node: Node) -> Any:
70
+ tensor_meta = node.meta.get("tensor_meta")
71
+
72
+ if not tensor_meta:
73
+ raise RuntimeError(
74
+ f"Node {node} has no tensor metadata associated with it! "
75
+ f"Check that shape propagation has run."
76
+ )
77
+
78
+ return tensor_meta
79
+
80
+
81
+ @compatibility(is_backward_compatible=False)
82
+ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
83
+ """Given a node with node.dtype and node.shape, return its total size and its output size.
84
+ total_size = weights + bias + output_size
85
+ """
86
+ # Total num of elements
87
+ total_num_of_elems = 0
88
+ # For a module, conside all parameters
89
+ if node.op == "call_module":
90
+ submodule_dict = dict(fx_module.named_modules())
91
+ submodule = submodule_dict[node.target]
92
+ parameters = submodule.named_parameters()
93
+ # Parameters are named tuples
94
+ for name, p in parameters:
95
+ total_num_of_elems += p.numel()
96
+ # Don't forget the output size
97
+ # node.shape is the shape of this node's output
98
+ tensor_meta = get_tensor_meta(node)
99
+ output_elem = tensor_meta.shape.numel()
100
+ total_num_of_elems += output_elem
101
+ # Assume for now if it's quantized then it's qint8 or quint8
102
+ if tensor_meta.is_quantized:
103
+ size_per_elem_bytes = torch._empty_affine_quantized(
104
+ [], dtype=tensor_meta.dtype
105
+ ).element_size()
106
+ else:
107
+ size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
108
+ total_size = size_per_elem_bytes * total_num_of_elems
109
+ output_size = size_per_elem_bytes * output_elem
110
+ return size_bytes(output_size, total_size)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.fx
7
+
8
+ from torch.fx._compatibility import compatibility
9
+ from torch.fx.node import map_arg
10
+
11
+ from .shape_prop import ShapeProp
12
+ from .split_utils import split_by_tags
13
+ from .tools_common import (
14
+ CALLABLE_NODE_OPS,
15
+ FxNetAccFusionsFinder,
16
+ Names,
17
+ NodeList,
18
+ NodeSet,
19
+ TensorOrTensors,
20
+ Tensors,
21
+ )
22
+
23
+ __all__ = [
24
+ "FxNetMinimizerBadModuleError",
25
+ "FxNetMinimizerRunFuncError",
26
+ "FxNetMinimizerResultMismatchError",
27
+ ]
28
+
29
+ _LOGGER = logging.getLogger(__name__)
30
+
31
+
32
+ @compatibility(is_backward_compatible=False)
33
+ class FxNetMinimizerBadModuleError(Exception):
34
+ """
35
+ Raised if failed to split out a minimize module
36
+ """
37
+
38
+ pass
39
+
40
+
41
+ @compatibility(is_backward_compatible=False)
42
+ class FxNetMinimizerRunFuncError(Exception):
43
+ """
44
+ Raised if error occurs during run_a or run_b functions
45
+ """
46
+
47
+ pass
48
+
49
+
50
+ @compatibility(is_backward_compatible=False)
51
+ class FxNetMinimizerResultMismatchError(Exception):
52
+ """
53
+ Raised if comparing function thinks the results are mismatching.
54
+ """
55
+
56
+ pass
57
+
58
+
59
+ @dataclass
60
+ class _MinimizerSettingBase:
61
+ """
62
+ Args:
63
+ `accumulate_error`: Instead of using a's input for both converted module to verify
64
+ , use the previous outputs of each converted module as input to accumulate the
65
+ errors.
66
+
67
+ `traverse_method`: "sequential" or "binary" or "accumulate"
68
+ Determine the way of traverse the nodes in FX module.
69
+
70
+ `find_all`: Minimizer will go through the entire model and return all problematic nodes.
71
+
72
+ `return_intermediate`: If true, when using `run_nodes()` function to run the
73
+ model, intermediate results of all the ops will be returned as output.
74
+ """
75
+
76
+ accumulate_error: bool = False
77
+ traverse_method: str = "sequential"
78
+ find_all: bool = False
79
+ return_intermediate: bool = False
80
+
81
+ def __str__(self):
82
+ settings_str = "FX Minimizer Settings:\n"
83
+
84
+ for k, v in vars(self).items():
85
+ settings_str += f"\t{k}: {v}\n"
86
+
87
+ return settings_str
88
+
89
+
90
+ class _MinimizerBase:
91
+ """
92
+ This class is used to automatically find problematic nodes in a model. It takes a FX
93
+ graphmodule and generate some submodules while traverse the graph. Then two functions
94
+ `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn`
95
+ will be used to compare the results.
96
+
97
+ Currently we provides two ways to traverse the graph and generate submodules.
98
+ 1. Sequential traversal: this will traverse the graph node by node and generate
99
+ one submodule with one sigle node.
100
+ 2. Binary searching: this will do a binary search style traversal on the graph.
101
+
102
+ For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ module: torch.fx.GraphModule,
108
+ sample_input: Tensors,
109
+ compare_fn: Callable[
110
+ [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
111
+ ],
112
+ settings: _MinimizerSettingBase,
113
+ module_exporter: Optional[
114
+ Callable[
115
+ [List[torch.Tensor], torch.fx.GraphModule, str],
116
+ None
117
+ ]
118
+ ] = None,
119
+ ):
120
+ assert isinstance(module, torch.fx.GraphModule)
121
+
122
+ self.module = module
123
+ self.sample_input = sample_input
124
+ self.compare_fn = compare_fn
125
+ self.module_exporter = module_exporter
126
+ self.settings = settings
127
+
128
+ # Stores outputs of run_a function
129
+ self.a_outputs: Dict[str, Any] = {}
130
+
131
+ # Stores outputs of run_b function
132
+ self.b_outputs: Dict[str, Any] = {}
133
+
134
+ # Stores the results of compare_fn
135
+ self.results: Dict[Any, Any] = {}
136
+
137
+ # Stores the report for the runs
138
+ self.reports: List[List[str]] = []
139
+
140
+ # Current iteration
141
+ self.iteration: int = 0
142
+
143
+ callable_nodes = {
144
+ node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
145
+ }
146
+ ShapeProp(self.module).propagate(*self.sample_input)
147
+ self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
148
+
149
+ # Check if number of input in sample_input matches the number of placeholders
150
+ placeholders = [
151
+ node.name for node in self.module.graph.nodes if node.op == "placeholder"
152
+ ]
153
+ assert len(placeholders) == len(self.sample_input)
154
+
155
+ # Store sample_input
156
+ for i, name in enumerate(placeholders):
157
+ self.a_outputs[name] = sample_input[i]
158
+ self.b_outputs[name] = sample_input[i]
159
+
160
+ def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
161
+ """
162
+ Run `mod` with `inputs` and generate output. The output will be compared with
163
+ output of run_b().
164
+ """
165
+ raise RuntimeError("run_a() is not implemented.")
166
+
167
+ def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
168
+ """
169
+ Run `mod` with `inputs` and generate output. The output will be compared with
170
+ output of run_a().
171
+ """
172
+ raise RuntimeError("run_b() is not implemented.")
173
+
174
+ def _store_outputs(
175
+ self,
176
+ a_result: TensorOrTensors,
177
+ b_result: TensorOrTensors,
178
+ submodule: torch.fx.GraphModule,
179
+ ):
180
+ """
181
+ Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
182
+ self.b_outputs, so that we can use them when execute preceding nodes that
183
+ use those outputs as inputs.
184
+
185
+ Args:
186
+ a_result: Output of self.run_a(). Could be a tensor or tensors.
187
+ b_result: Output of self.run_b(). Could be a tensor or tensors.
188
+ submodule: The module that generates a_result and b_result.
189
+ """
190
+ output_node = next(
191
+ node for node in submodule.graph.nodes if node.op == "output"
192
+ )
193
+
194
+ # Only one output
195
+ if isinstance(output_node.args[0], torch.fx.Node):
196
+ self.a_outputs[output_node.args[0].name] = a_result
197
+ self.b_outputs[output_node.args[0].name] = b_result
198
+ # Multiple outputs
199
+ else:
200
+ for i, arg in enumerate(output_node.args[0]):
201
+ self.a_outputs[arg.name] = a_result[i]
202
+ self.b_outputs[arg.name] = b_result[i]
203
+
204
+ def _get_submod_inputs(
205
+ self, main_module: torch.fx.GraphModule, submod_path: str
206
+ ) -> Tuple[Tensors, Tensors]:
207
+ """
208
+ Try get submodule inputs from stored outputs. If not found then use
209
+ torch_glow.get_submod_inputs to get the inputs.
210
+
211
+ If accumulate_error is False, use a_input for run_a() and run_b()
212
+ otherwise use a_input for run_a and b_input for run_b.
213
+
214
+ Args:
215
+ main_module: Top-levlel fx module.
216
+ submod_path: Path to the submodule we want to run and compare results.
217
+
218
+ Returns:
219
+ a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
220
+ b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
221
+ """
222
+ a_input = []
223
+ b_input = []
224
+ submodule = getattr(main_module, submod_path)
225
+ placeholders = [
226
+ node.name for node in submodule.graph.nodes if node.op == "placeholder"
227
+ ]
228
+
229
+ # If all placeholder can be found in stored outputs, use stored
230
+ # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs`
231
+ # to get the inputs.
232
+ if set(placeholders) <= self.a_outputs.keys():
233
+ for name in placeholders:
234
+ a_input.append(self.a_outputs[name])
235
+ b_input.append(self.b_outputs[name])
236
+ else:
237
+ if self.settings.accumulate_error:
238
+ print(f"Can't find previous stored outputs named {placeholders}!")
239
+
240
+ def get_inputs(self: torch.nn.Module, inputs: Any):
241
+ nonlocal a_input
242
+ a_input = inputs
243
+
244
+ # Use forward hook to get the inputs to the submodule
245
+ handle = submodule.register_forward_pre_hook(get_inputs)
246
+ main_module(*self.sample_input)
247
+ handle.remove()
248
+
249
+ b_input = a_input
250
+
251
+ if not self.settings.accumulate_error:
252
+ return a_input, a_input
253
+
254
+ return a_input, b_input
255
+
256
+ def _tag_nodes(self, selected_nodes: NodeSet):
257
+ """
258
+ Tag selected nodes with tag "minimize". Nodes with the same tags will
259
+ be split to the same submodule afterwards.
260
+
261
+ Args:
262
+ selected_nodes: Nodes that we want to minimize. We will tag those nodes
263
+ with "minimize", all preceding nodes with "main_0" and all following
264
+ nodes with "main_1".
265
+ """
266
+ for node in self.module.graph.nodes:
267
+ if node.op not in CALLABLE_NODE_OPS:
268
+ continue
269
+
270
+ if node in selected_nodes:
271
+ node.tag = "minimize"
272
+ elif any(
273
+ n.tag in {"minimize", "main_1"}
274
+ for n in node.all_input_nodes
275
+ if n.op in CALLABLE_NODE_OPS
276
+ ):
277
+ node.tag = "main_1"
278
+ else:
279
+ node.tag = "main_0"
280
+
281
+ def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
282
+ """
283
+ Split self.module so that one submodule consists of `nodes` and only `nodes`.
284
+
285
+ Args:
286
+ nodes: Nodes that we want to include in the minimize submodule.
287
+
288
+ Returns:
289
+ split_module (torch.fx.GraphModule): the module after split.
290
+ submodule_name (str): the name of the submodule that consists of `nodes`.
291
+ """
292
+ # Color provided nodes
293
+ self._tag_nodes(nodes)
294
+
295
+ # Split module based on coloring
296
+ split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
297
+
298
+ # Find submodule containing colored nodes
299
+ submodule_name: str = ""
300
+ for child_name, _ in split_module.named_children():
301
+ # Skip submodules we're not interested in at the moment
302
+ if "minimize" not in child_name:
303
+ continue
304
+
305
+ if submodule_name == "":
306
+ submodule_name = child_name
307
+ else:
308
+ raise FxNetMinimizerBadModuleError(
309
+ f"Expected only one minimize submodule with nodes {nodes}"
310
+ )
311
+
312
+ if submodule_name == "":
313
+ raise FxNetMinimizerBadModuleError(
314
+ f"Minimize submodule was not found with nodes {nodes}"
315
+ )
316
+
317
+ return split_module, submodule_name
318
+
319
+ def _run_and_compare(
320
+ self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
321
+ ):
322
+ """
323
+ Run the submodule in `split_module` that has name `submod_name`
324
+ using `self.run_a` and `self.run_b` and compare their results.
325
+
326
+ Args:
327
+ split_module: Main module that contains the minimize submodule.
328
+ submod_name: Name of the minimize submodule.
329
+ output_names: Names of the node we want to output. If None, we
330
+ will use the original output.
331
+ """
332
+ submodule = getattr(split_module, submod_name)
333
+ a_input, b_input = self._get_submod_inputs(split_module, submod_name)
334
+
335
+ if len(self.reports) == 0:
336
+ self.reports.append([])
337
+ self.iteration = 1
338
+
339
+ report = self.reports[self.iteration - 1]
340
+ report.append("Run and compare ...")
341
+
342
+ if output_names:
343
+ output_nodes: NodeList = []
344
+ for node in submodule.graph.nodes:
345
+ if node.op == "output":
346
+ submodule.graph.erase_node(node)
347
+
348
+ if node.name in output_names:
349
+ output_nodes.append(node)
350
+
351
+ submodule.graph.output(
352
+ output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
353
+ )
354
+ submodule.graph.lint()
355
+ submodule.recompile()
356
+
357
+ # Use name of args in output node as key to store comparison result
358
+ for node in submodule.graph.nodes:
359
+ if node.op == "output":
360
+ result_key = map_arg(node.args, lambda x: x.name)
361
+
362
+ try:
363
+ a_result = self.run_a(submodule, a_input)
364
+ b_result = self.run_b(submodule, b_input)
365
+ self._store_outputs(a_result, b_result, submodule)
366
+ except Exception as e:
367
+ report.append(f"Exception raised when running {submod_name}: {e}")
368
+ raise FxNetMinimizerRunFuncError( # noqa: TRY200
369
+ f"Exception raised when running {submod_name}: {e}"
370
+ )
371
+
372
+ # Compare results
373
+ names: Names = output_names
374
+ if output_names is None:
375
+ names = [str(v) for v in result_key] # type: ignore[possibly-undefined]
376
+
377
+ numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
378
+
379
+ self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
380
+ report.append(f"Numerical accuracy = {numeric_result}")
381
+ if not bool_result:
382
+ report.append(f"Result mismatch for {result_key}")
383
+ if self.module_exporter:
384
+ self.module_exporter(
385
+ List[torch.Tensor](a_input), submodule, str(result_key[0]) + "_cpu",
386
+ )
387
+ self.module_exporter(
388
+ List[torch.Tensor](b_input), submodule, str(result_key[0]) + "_acc",
389
+ )
390
+ raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
391
+
392
+ def _binary_search_impl(
393
+ self, all_nodes: NodeList, start_idx: int, end_idx: int
394
+ ) -> NodeSet:
395
+ """
396
+ Recursive binary search implementation.
397
+ """
398
+ nodes: NodeList = all_nodes[start_idx:end_idx]
399
+
400
+ report: List[str] = []
401
+ self.reports.append(report)
402
+ self.iteration += 1
403
+ report.append(f"Binary search iteration {self.iteration}.")
404
+ report.append(
405
+ f"From node index {start_idx} to {end_idx-1}. "
406
+ f"Size of the interested node list is {len(nodes)}"
407
+ )
408
+
409
+ cur_nodes: NodeSet = set(nodes)
410
+
411
+ for node in nodes:
412
+ if node in self.fusions:
413
+ cur_nodes.update(self.fusions[node])
414
+
415
+ try:
416
+ split_module, submod_name = self._build_submodule(cur_nodes)
417
+ self._run_and_compare(split_module, submod_name, [])
418
+ except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
419
+
420
+ if len(nodes) == 1:
421
+ report.append(
422
+ f"This is the last node in the sub-module. "
423
+ f"Search in the current branch is successful with culprit = {cur_nodes}."
424
+ )
425
+ self.print_report(report)
426
+ return cur_nodes
427
+
428
+ report.append(
429
+ "Proceed to split and lower the halves of the current "
430
+ "sub-module individually."
431
+ )
432
+ self.print_report(report)
433
+
434
+ mid = len(nodes) // 2
435
+ culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
436
+
437
+ if len(culprits) != 0 and not self.settings.find_all:
438
+ return culprits
439
+
440
+ culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
441
+
442
+ if len(culprits) == 0:
443
+ report.append(
444
+ f"Further split and lowering found no errors. "
445
+ f"Unable to minimize the submodule with list of nodes: {nodes}"
446
+ )
447
+ self.print_report(report)
448
+
449
+ return culprits
450
+ else:
451
+ report.append("No discrepancy found.")
452
+ self.print_report(report)
453
+ return set()
454
+
455
+ def _binary_traverse(self, nodes: NodeList) -> NodeSet:
456
+ """
457
+ Binary search on `nodes` for culprit.
458
+ """
459
+ return self._binary_search_impl(nodes, 0, len(nodes))
460
+
461
+ def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
462
+ """
463
+ Traverse `nodes` one by one and determine if any of them is a culprit.
464
+ """
465
+ culprits: NodeSet = set()
466
+
467
+ for node in nodes:
468
+ report: List[str] = []
469
+ self.reports.append(report)
470
+ self.iteration += 1
471
+ report.append(f"Sequential traverse iteration {self.iteration}.")
472
+ report.append(f"Visit node: {node.name}")
473
+
474
+ _LOGGER.info("Visit node: %s", node.name)
475
+ cur_nodes: NodeSet = {node}
476
+
477
+ if node in self.fusions:
478
+ cur_nodes = self.fusions[node]
479
+
480
+ try:
481
+ split_module, submod_name = self._build_submodule(cur_nodes)
482
+ self._run_and_compare(split_module, submod_name, [node.name])
483
+ self.print_report(report)
484
+ except (FxNetMinimizerResultMismatchError):
485
+ culprits.add(node)
486
+ report.append(f"Found culprit from numeric error: {node}")
487
+ self.print_report(report)
488
+ if not self.settings.find_all:
489
+ return culprits
490
+ except (FxNetMinimizerRunFuncError):
491
+ culprits.update(cur_nodes)
492
+ report.append(f"Found culprit from run error: {node}")
493
+ self.print_report(report)
494
+ if not self.settings.find_all:
495
+ return culprits
496
+
497
+ return culprits
498
+
499
+ def _defined_traverse(self, nodes: NodeList) -> NodeSet:
500
+ """
501
+ run user defined `nodes` and determine if it is a culprit.
502
+ """
503
+ culprits: NodeSet = set()
504
+
505
+ first_node_name = nodes[0].name
506
+ output_node_name = nodes[-1].name
507
+ report = [f"Defined graph from {first_node_name} to {output_node_name}"]
508
+ cur_nodes: NodeSet = set(nodes)
509
+ try:
510
+ split_module, submod_name = self._build_submodule(cur_nodes)
511
+ self._run_and_compare(split_module, submod_name, [output_node_name])
512
+ self.print_report(report)
513
+ except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
514
+ report.append(f"Found culprit {cur_nodes}")
515
+ self.print_report(report)
516
+ return culprits
517
+
518
+ return culprits
519
+
520
def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
    """
    Lower an ever-growing prefix of `nodes`, one node per iteration, and
    stop at the first node whose addition makes the submodule fail.
    """
    culprits: NodeSet = set()
    accumulated: NodeSet = set()

    # find_all is not supported for accumulate traversal because all the
    # ops run on NNPI. So we return after the first op that raises error.
    if self.settings.find_all:
        print("'Find All' mode is not supported in accumulate traversal.")
        return culprits

    for node in nodes:
        iteration_report: List[str] = []
        self.reports.append(iteration_report)
        self.iteration += 1
        iteration_report.append(f"Accumulate traverse iteration {self.iteration}.")

        accumulated.add(node)

        # Some backends report names as 1-tuples; unwrap to the plain string.
        resolved_name = node.name
        if resolved_name is not None and isinstance(resolved_name, tuple):
            resolved_name = resolved_name[0]
        assert resolved_name is not None and isinstance(
            resolved_name, str
        ), f"minimize: node_name: {resolved_name}"

        iteration_report.append(f"Add node: {resolved_name}")

        try:
            split_module, submod_name = self._build_submodule(accumulated)
            self._run_and_compare(split_module, submod_name, [resolved_name])
            self.print_report(iteration_report)
        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
            # First failing addition is the culprit; stop immediately.
            culprits.add(node)
            iteration_report.append(f"Found culprit {node}")
            self.print_report(iteration_report)
            return culprits

    return culprits
558
+
559
def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
    """
    Lower the slice all_nodes[start_idx:end_idx] as one submodule and
    determine whether that block reproduces the error being minimized.
    """
    segment: NodeList = all_nodes[start_idx:end_idx]

    block_report: List[str] = []
    self.reports.append(block_report)
    self.iteration += 1
    block_report.append(f" Nodes block {self.iteration}.")
    block_report.append(
        f"From node index {start_idx} to {end_idx-1}. "
        f"Size of the interested node list is {len(segment)}"
    )

    lowered: NodeSet = set(segment)
    for node in segment:
        # Pull in fused siblings so the submodule stays self-contained.
        if node in self.fusions:
            lowered.update(self.fusions[node])

    try:
        split_module, submod_name = self._build_submodule(lowered)
        self._run_and_compare(split_module, submod_name, [])
    except FxNetMinimizerResultMismatchError:
        block_report.append(f"Found culprit from numeric error: {lowered}")
        self.print_report(block_report)
        return set(lowered)
    except FxNetMinimizerRunFuncError:
        block_report.append(f"Found culprit from run error: {node}")
        self.print_report(block_report)
        return set(lowered)
    else:
        block_report.append("No discrepancy found.")
        self.print_report(block_report)
        return set()
598
+
599
+
600
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
    """
    Walk `all_nodes`, skipping every node whose name appears in
    `skip_nodes`, and check each maximal run of un-skipped nodes as one
    block via `_skip_traverse_impl`.

    Args:
        all_nodes: all candidate nodes, in topological order.
        skip_nodes: names of nodes that must not be lowered.

    Returns:
        The union of culprit nodes found across every checked block.
    """
    start_idx = 0
    num_nodes = len(all_nodes)
    idx = 0
    culprits = set()
    while idx < num_nodes:
        node = all_nodes[idx]
        if (node.name in skip_nodes):  # skip the node
            if idx > start_idx:
                # Bug fix: accumulate culprits from every block instead of
                # overwriting the results collected from earlier blocks.
                culprits.update(self._skip_traverse_impl(all_nodes, start_idx, idx))
            start_idx = idx + 1
        elif idx == num_nodes - 1 and start_idx <= idx:  # last node
            culprits.update(self._skip_traverse_impl(all_nodes, start_idx, idx + 1))
        idx += 1

    return culprits
619
+
620
+
621
+
622
def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
    """
    Collect the callable nodes of the model between the node named
    `start` and the node named `end`, inclusive on both ends.  A `start`
    of None collects from the first node; an `end` of None collects
    through the last node.
    """
    collected: NodeList = []
    collecting = start is None

    for node in self.module.graph.nodes:
        # Placeholders/attributes/outputs are not runnable ops; skip them.
        if node.op not in CALLABLE_NODE_OPS:
            continue

        if node.name == start:
            collecting = True

        if collecting:
            collected.append(node)

        if node.name == end:
            break

    return collected
644
+
645
+ def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
646
+ """
647
+ Run part of the model from `start` node to `end` node. If `start` is None
648
+ then we start from the beginning of the model. If `end` is None then we
649
+ stop at the end of the model.
650
+
651
+ Args:
652
+ start: The name of the node which is the first node of the submodule
653
+ we want to run. If set to None, then we'll start with the first
654
+ node of the model.
655
+ end: The name of the node which is the last node of the submodule we
656
+ want to run. If set to None, we'll end with the last node of the
657
+ model.
658
+ """
659
+ nodes = self._collect_nodes(start, end)
660
+ cur_nodes = set(nodes)
661
+
662
+ for node in nodes:
663
+ if node in self.fusions:
664
+ cur_nodes.update(self.fusions[node])
665
+
666
+ output_names = []
667
+ if self.settings.return_intermediate:
668
+ output_names = [node.name for node in nodes]
669
+
670
+ try:
671
+ split_module, submod_name = self._build_submodule(cur_nodes)
672
+ self._run_and_compare(split_module, submod_name, output_names)
673
+ except (
674
+ FxNetMinimizerRunFuncError,
675
+ FxNetMinimizerResultMismatchError,
676
+ ) as e:
677
+ print(e)
678
+
679
def print_report(self, report: List[str]):
    """Print one report: the first entry flush-left, the rest indented."""
    for position, entry in enumerate(report):
        prefix = " . " if position > 0 else ""
        print(prefix + entry)
685
+
686
def print_reports(self):
    """Print every report recorded so far, in recording order."""
    for recorded in self.reports:
        self.print_report(recorded)
689
+
690
def minimize(
    self, start: Optional[str] = None, end: Optional[str] = None, skip_nodes: Optional[List] = None,
) -> NodeSet:
    """
    Minimizing the model from node with name `start` to node with name `end` base
    on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
    FxNetMinimizerResultMismatchError errors.

    Args:
        start: The name of the node where we want to start minimizing. If set
            to None, then we'll start with the first node of the model.
        end: The name of the node where we want to terminate minimizing. If
            set to None, we'll end with the last node of the model.
        skip_nodes: names to exclude (required for the "skip" method).

    Returns:
        nodes: A list of nodes that causes FxNetMinimizerRunFuncError or
            FxNetMinimizerResultMismatchError errors during minimizing.
    """
    print(self.settings)
    print(self.module.graph)

    nodes = self._collect_nodes(start, end)
    method = self.settings.traverse_method

    # "skip" needs its extra argument validated before dispatch.
    if method == "skip":
        if skip_nodes is None:
            raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
        return self._skip_traverse(nodes, skip_nodes)

    handlers = {
        "sequential": self._sequential_traverse,
        "binary": self._binary_traverse,
        "accumulate": self._accumulate_traverse,
        "defined": self._defined_traverse,
    }
    if method in handlers:
        return handlers[method](nodes)

    raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from inspect import unwrap
3
+ from typing import Callable, List, Optional
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ __all__ = [
9
+ "PassManager",
10
+ "inplace_wrapper",
11
+ "log_hook",
12
+ "loop_pass",
13
+ "this_before_that_pass_constraint",
14
+ "these_before_those_pass_constraint",
15
+ ]
16
+
17
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Args:
        fn (Callable[Object, Any])

    Returns:
        wrapped_fn (Callable[Object, Object])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's own return value is intentionally discarded;
        # the (mutated) input object is handed back instead.
        fn(gm)
        return gm

    return wrapped_fn
37
+
38
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.


    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2])
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        # Lazy %-style args keep formatting cost off the non-logging path.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return wrapped_fn
76
+
77
+
78
+
79
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter`or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass
        predicate (Callable[Object, bool], optional):

    """
    assert (n_iter is not None) ^ (
        predicate is not None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        result = source
        if predicate is not None:
            # Keep applying the pass while the predicate holds.
            while predicate(result):
                result = base_pass(result)
        elif n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                result = base_pass(result)
        else:
            # Reached only with a non-positive n_iter (predicate is None here).
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return result

    return new_pass
112
+
113
+
114
# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Raise if any ordered pair of passes violates `constraint`.

    `constraint(a, b)` must return True when it is acceptable for `a` to
    run before `b`; the check covers every pair with `a` earlier in
    `passes` than `b`.

    Raises:
        RuntimeError: describing the offending pair and their indices.
    """
    for i, a in enumerate(passes):
        # Bug fix: enumerate the slice with start=i+1 so `j` is the index
        # of `b` in `passes`, not in the slice (it was off by i+1); the
        # message also gained the missing space in "index {j}".
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {a} before {b}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
130
+
131
+
132
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.
    """

    def depends_on(a: Callable, b: Callable):
        # Only the single forbidden ordering (`that` before `this`) fails.
        return not (a == that and b == this)

    return depends_on
144
+
145
+
146
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Defines a partial order ('depends on' function) where `these` must occur
    before `those`. Where the inputs are 'unwrapped' before comparison.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]
    """

    def depends_on(a: Callable, b: Callable):
        # Unwrap decorator layers (e.g. loop_pass) before comparing.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
177
+
178
+
179
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraint (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            True if A depends on B and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Caches whether the current schedule passed validation; any mutation
    # of passes/constraints resets it.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        # Bug fix: use `cls` (not the hardcoded PassManager) so subclasses
        # construct instances of themselves.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append a pass and invalidate the cached schedule validation."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Append a constraint and invalidate the cached schedule validation."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """Drop every pass whose ``__name__`` appears in `_passes`."""
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """Replace passes matching `_target` (by ``__name__``) with `_replacement`."""
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
3
+ from collections import OrderedDict
4
+ import logging
5
+
6
+ import torch
7
+ from torch.fx._compatibility import compatibility
8
+ from torch.fx.graph_module import GraphModule
9
+ from torch.fx.node import Node
10
+
11
+ if TYPE_CHECKING:
12
+ import sympy # noqa: F401
13
+
14
+ __all__ = ["Partition", "split_module"]
15
+ _LOGGER = logging.getLogger(__name__)
16
+
17
@compatibility(is_backward_compatible=True)
class Partition:
    """One partition of a split graph: its nodes, I/O, and dependency links."""

    def __init__(self, name: str):
        self.name: str = name
        self.submod_name = f"submod_{name}"
        self.node_names: List[str] = []
        # inputs/outputs/dependencies/dependents are ordered "sets",
        # represented as dicts with None values.
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        # Maps original-graph nodes to their clones inside self.graph.
        self.environment: Dict[Node, Node] = {}
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        lines = [
            f"name: {self.name},",
            f" nodes: {self.node_names},",
            f" inputs: {self.inputs},",
            f" outputs: {self.outputs},",
            f" partitions depended on: {self.dependencies},",
            f" partition dependents: {self.dependents}",
        ]
        return "\n".join(lines)
40
+
41
+
42
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the Topological order of the new constructed GraphModule
        keep_original_node_name: Optional[bool]: when True, placeholders in the base
            module keep the original node names (created via create_node).

    Returns:
        GraphModule: the module after split.

    Example:

        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. Original graph is broken into partitions

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1];  submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        Output of split module is the same as output of input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
    """

    # Copies a placeholder/get_attr node from the original graph into the
    # base (outer) module's graph, recording fetched attribute values.
    def construct_graph(
        node: Node,
        base_mod_env: Dict[str, Node],
        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
    ):
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = () if default_value is inspect.Signature.empty else (default_value,)
                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
            base_mod_env[node.name].meta = node.meta.copy()
            # Resolve the dotted attribute path on the original module so the
            # attribute object itself can be installed on the base module.
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
    # Maps SymInt symbols to the placeholder node that carries them
    # (populated while scanning placeholders below).
    symbol_to_node: Dict["sympy.Symbol", Node] = {}

    # Records that `use_node` (in one partition) consumes `def_node`
    # (defined in a different partition or at top level), wiring up the
    # partitions' inputs/outputs and dependency edges.
    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)
        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # Symbolic shapes referenced by the value must also flow in
                # as inputs, via the placeholders that carry those symbols.
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        use_partition.inputs.setdefault(symbol_to_node[s].name)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    # Assigns `node` to the partition chosen by split_callback, creating
    # the Partition on first use and tagging the node with its partition.
    def instantiate_node_partition_mapping(node):
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global State Nodes are nodes which by their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    # Runtime import: the top-of-file sympy import is TYPE_CHECKING-only.
    import sympy  # noqa: F811

    # First scan: assign partitions and record grad/autocast regions.
    for node in m.graph.nodes:
        if node.op in ["placeholder", "get_attr", "output"]:
            if (
                node.op == "placeholder" and
                (val := node.meta.get("example_value")) is not None and
                isinstance(val, torch.SymInt) and
                isinstance(val.node.expr, sympy.Symbol)
            ):
                symbol_to_node[val.node.expr] = node
            continue

        instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = set({split_callback(node)})
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = set({split_callback(node)})
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        # Every node inside an active region taints that region's partition set.
        if active_grad is not None:
            grad_regions[active_grad].add(split_callback(node))

        for a in active_autocasts:
            autocast_regions[a].add(split_callback(node))

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    # Sort each region's partition ids so prelude/epilogue insertion below
    # can rely on first/last ordering.
    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = split_callback(node)
            assert highest_partition <= pid, \
                ("autocast or set_grad_enabled require monotonically increasing partitions:"
                 f"highest: {highest_partition}, this node's: {pid}")
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude: replicate each region's enter node at the start of every
    # partition after the first one the region spans.
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(arg for arg in node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # Flatten dotted qualnames into a single-level name inside
                # the submodule (e.g. "sub.linear" -> "sub_linear").
                # target = target_atoms[-1]
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue: replicate each autocast's exit node at the end of every
    # partition the region spans except the last one.
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: Dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    already_constructed_attr_nodes = set()
    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # first get the attr nodes required by this partition
            orig_mod_attr_nodes: List[Node] = [
                orig_mod_env[key] for key in partition.inputs
            ]
            # Construct GraphModule for this partition
            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
6
+ import logging
7
+
8
+ import torch
9
+ from torch.fx.passes.graph_manipulation import get_size_of_node
10
+ from torch.fx.node import map_arg
11
+ from torch.fx._compatibility import compatibility
12
+
13
+ from .operator_support import (
14
+ get_node_target,
15
+ OperatorSupportBase,
16
+ )
17
+ from .graph_drawer import FxGraphDrawer
18
+ from .shape_prop import ShapeProp
19
+ from .split_utils import split_by_tags
20
+ from .tools_common import (
21
+ FxNetAccFusionsFinder,
22
+ CALLABLE_NODE_OPS,
23
+ Tensors,
24
+ NodeList,
25
+ NodeSet,
26
+ is_node_output_tensor,
27
+ )
28
+
29
+
30
+ __all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
31
+ _LOGGER = logging.getLogger(__name__)
32
+
33
+ DEFAULT_MIN_ACC_MODULE_SIZE = 1
34
+ DEFAULT_SKIP_FUSION = False
35
+ DEFAULT_ALLOW_NON_TENSOR = False
36
+
37
class _SplitterSettingBase:
    """Configuration holder for the splitter.

    Each setting can come from two places: a keyword argument to this
    constructor, or a command line flag parsed from ``sys.argv``
    (``--min-acc-module-size``, ``--skip-fusion``, ``--allow-non-tensor``).
    A flag explicitly given on the command line overrides the corresponding
    constructor argument.
    """

    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true then no fusion groups. Fusion group is used to "
            "enforce no non-tensor data flow between submodules. If we don't "
            "have this constrain, setting this to false is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="For some backends non-tensor data flow between cpu and them "
            "are not allowed. Therefore, if a node supported by accelerator but "
            "it has non-tensor inputs or outputs to a cpu node we would want to "
            "consider it as a cpu node during splitting. However, for some backends "
            "we might not care about non-tensor data flow and we can set this option "
            "to true to disable the functionality that prevent non-tensor data flow.",
        )
        # parse_known_args ignores unrelated flags (e.g. the host program's
        # own CLI options) instead of erroring out on them.
        args, unknown = parser.parse_known_args()

        # `--min-acc-module-size` defaults to None when absent; compare with
        # `is not None` so an explicit falsy value (e.g. `0`) on the command
        # line is honored rather than silently replaced by the kwarg.
        self.min_acc_module_size: int = (
            args.min_acc_module_size
            if args.min_acc_module_size is not None
            else min_acc_module_size
        )
        # The two store_true flags are False unless passed, so a plain truthy
        # check is sufficient: flag given -> True, otherwise use the kwarg.
        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
79
+
80
+
81
@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        # module: graph module whose callable nodes get classified.
        # operator_support: per-node oracle deciding accelerator support.
        # allow_non_tensor: when True, skip the non-tensor data-flow pruning.
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.
        """
        # BFS over users: once a node is forced onto CPU, any ACC user of it
        # must also run on CPU; if that user itself emits non-tensor data,
        # the demotion keeps propagating downstream.
        while cpu_worklist:
            node = cpu_worklist.pop(0)

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    if not is_node_output_tensor(user):
                        cpu_worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        # Seed the worklist with every CPU node whose output is non-tensor.
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        # Iterate to a fixed point: demoting one ACC node can expose new
        # ACC nodes whose non-tensor output now flows to a CPU node.
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        break

            # No more demotions -> fixed point reached.
            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            # The freshly demoted nodes may cascade further exclusions.
            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        # Start from every callable node the operator-support oracle accepts,
        # then (unless disabled) prune nodes involved in non-tensor data flow
        # across the ACC/CPU boundary.
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
181
+
182
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    """Raised when the splitter reaches an inconsistent internal state
    (e.g. an empty subgraph or a node tagged twice)."""
    pass
185
+
186
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    """One contiguous partition of the graph produced by the splitter."""
    # True if this subgraph is meant to run on the accelerator, False for CPU.
    is_acc: bool
    # Nodes belonging to this subgraph, in the order they were visited.
    nodes: NodeList
191
+
192
+
193
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str
208
+
209
+
210
@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for targeted submodules in the given model. Note that if two
    submodules refer to the same obj, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.
        deepcopy: if True, store a deep copy of each submodule's inputs so
            later in-place mutation by the model cannot corrupt the record.

    Returns:
        A dict that maps from submodule name to its inputs.
    """

    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        # Record the positional inputs seen by this submodule. If the module
        # runs more than once, the last invocation wins.
        results[submodule_to_names[module]] = (
            copy.deepcopy(module_inputs) if deepcopy else module_inputs
        )

    for name, mod in model.named_modules():
        if name in target_submodules:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    # Run the model once to trigger the hooks. Use try/finally so the hooks
    # are always removed, even on exceptions that are not subclasses of
    # Exception (e.g. KeyboardInterrupt), which the previous
    # `except Exception` + manual-cleanup pattern would have leaked.
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()

    return results
254
+
255
+
256
class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.

    Given the following graph:
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
    we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend, default to 100 GB/s.
    # Used only for the theoretical-QPS estimate in split_preview().
    PCIe_BW = 100 * 2 ** 30

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        # Propagate shapes so nodes carry tensor metadata; the splitter
        # relies on it to distinguish tensor from non-tensor outputs.
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name
        # Filled in by tag(): node name -> submodule tag (e.g. "_run_on_acc_1").
        self._node_submodule_map: Dict[str, str] = {}

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """ Returns a map from node name to submodule name, e.g.
            node: main_module_impl_impl_over_arch_unary_multiple_embedding
              _pooling_embedding_pooling_sparse_entity_equivalence_key
              _proxy_embedding_bag
            maps to submodule name of: _run_on_acc_1
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        Resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            # Record `node` as a direct dependency of each of its users,
            # except the graph's "output" sink which is never partitioned.
            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps

    def update_deps_for_fusions(self):
        """
        Updates graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                # Inherit the neighbor's external deps (deps inside the
                # fusion itself are excluded).
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                # Anything outside the fusion that consumes a fused node now
                # depends on every member of the fusion.
                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.

        Default implementation is a no-op; subclasses are expected to
        override this with a real lowering step.
        """

        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find culprits in the `mod` that causes the error.

        Default implementation returns a placeholder message; subclasses
        should override it with a real minimizer.
        """

        return "Unable to find a culprit because _find_culprit() function is not implemented."

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        # Render the graph to "node_support.dot", coloring supported nodes
        # green, unsupported callable nodes red, everything else blue.
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        """Print (and return) a report of supported/unsupported node types,
        including the argument dtypes observed for each target."""
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Find last non-None element. If all elements are None, return max_len.
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return reports for testing purpose
        return reports

    def split_preview(self, dump_graph: bool = False):
        """Dry-run the split and report subgraph statistics, I/O sizes and a
        theoretical PCIe-bound max QPS per acc submodule; also tries lowering
        and running each acc submodule. NOTE: this tags and splits the module
        (tags are removed afterwards via remove_tag=True)."""
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                # Capture the actual runtime inputs of this submodule via a
                # forward pre-hook while running the full split module once.
                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(self, inputs):
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(
                    split_mod, submod, self.sample_input
                )
                ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if not is_node_output_tensor(n):
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if not is_node_output_tensor(node):
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
                # NOTE(review): divides by max(in, out) bytes — would raise
                # ZeroDivisionError if all IO were non-tensor; presumably acc
                # submodules always have at least one tensor input or output.
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
                except RuntimeError:
                    reports += "Run into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Run into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeed!\n"

        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # return the reports for testing purposes
        return reports

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies, if tag_id is specified,
        we ignore nodes that are in later subgraph i.e. nodes have greater tag_id.

        NOTE: reads ``node.tag``, so tag() must have run before this.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Tags look like "<prefix>_<id>", so the numeric suffix
                # identifies the subgraph's position.
                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        """Make each fusion group behave as a single unit in the reversed
        dependency graph (mirror of update_deps_for_fusions)."""
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which include all the
            # dependencies of the nodes in the fusion group
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion
            new_dep.difference_update(fusion)

            # Update dependency
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph
        and is not a placeholder, we consider it as the parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes

    def extend_acc_subgraph(self, tag: str):
        """
        Extend the acc subgraph with `tag` going the reversed topological direction.
        """
        # Dict that maps node to its users and ignore users that
        # are in the subgraph that has greater tag
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find a acc node that depends on visited nodes only
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into `tag` subgraph
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If node is in a fusion group, add all fusion buddies to parent nodes
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
        """Greedy topological traversal that groups nodes into alternating
        ACC/CPU subgraphs, keeping fusion groups together."""
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # 0-dep node
        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

        current_subgraph_nodes: NodeList = []

        # Result accumulator
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put depending nodes into the queue
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Check if the last subgraph was not created
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules with less than specified size and merges
        them with adjacent CPU submodules.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    # Fold the undersized acc subgraph into the preceding
                    # subgraph, or demote it to CPU if it is the first one.
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                # Merge consecutive CPU subgraphs (can appear after folding).
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result

    def tag(self, subgraphs: List[Subgraph]):
        """Attach a ``tag`` attribute ("_run_on_acc_<i>" or
        "<non_acc_prefix><i>") to every node, recording the mapping in
        self._node_submodule_map."""
        self.tags: List[str] = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        """Split self.module by the tags assigned in tag(). With
        remove_tag=True the tags are cleared afterwards so the splitter can
        be re-run on the same module."""
        split_module = split_by_tags(self.module, self.tags)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module

    def __call__(self) -> torch.fx.GraphModule:
        """Full pipeline: partition, drop small acc subgraphs, tag, split."""
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        """Run the split and capture each submodule's runtime inputs by
        executing the split module once on self.sample_input."""
        split_module = self()
        submodule_names = []
        for name, mod in split_module.named_children():
            submodule_names.append(name)
        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
2
+ import collections
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.fx
7
+ from torch.fx.node import _get_qualified_name
8
+ from torch.fx._compatibility import compatibility
9
+
10
+ __all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']
11
+
12
+ Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
13
+ TensorOrTensors = Union[torch.Tensor, Tensors]
14
+ NodeList = List[torch.fx.Node]
15
+ NodeSet = Set[torch.fx.Node]
16
+ Names = List[str]
17
+ CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
18
+
19
+
20
@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    """Return a printable target name for ``k``.

    ``k`` may already be a string (returned unchanged), an ``acc_ops``
    callable (rendered as ``acc_ops.<name>``), or any other callable
    (rendered as ``<module>.<name>``).
    """
    if isinstance(k, str):
        return k
    if k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    # WAR for bug in how torch.ops assigns module
    module = k.__module__.replace('torch._ops', 'torch.ops')
    return f"{module if module else ''}.{k.__name__}"
29
+
30
+
31
@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Given a `node` returns its target typename.

    For "call_method" node, return node.target which is the name of that method being called.
    This could potential lead to conflict but should be okay because normally it's on a tensor.

    For "call_function" node, return typename of node.target.

    For "call_module" node, return typename of the module that node.target point to.

    If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
    "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
    """

    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        wrapped = submodules[node.target]
        # Prefer the pre-wrapping class recorded by tracing rewrites, if any.
        origin = getattr(wrapped, "_base_class_origin", type(wrapped))
        return get_acc_ops_name(origin)

    if node.op == "call_function":
        fn: Any = node.target
        if fn.__module__ is not None and "acc_ops" in fn.__module__:
            return f"acc_ops.{fn.__name__}"
        return _get_qualified_name(fn)

    # call_method: the target is the method name itself.
    assert isinstance(node.target, str)
    return node.target
66
+
67
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks if the node output produces a Tensor or not.

    NOTE: This requires to run `ShapeProp` on the containing fx graph before
    calling this function. This is because it works by checking the `type`
    metadata on the node. This metadata is produced by the `ShapeProp`.
    """
    node_type = node.meta.get("type", None)
    if node_type is None:
        # No type metadata recorded (e.g. ShapeProp never ran) -> not a tensor.
        return False
    return issubclass(node_type, torch.Tensor)
77
+
78
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        # Graph module being analyzed. ``self.nodes`` caches the graph's nodes
        # as a list so positions (topological order indices) can be compared.
        self.module = module
        self.nodes = list(module.graph.nodes)
        # Set of nodes designated to run on the accelerator. NOTE: mutated by
        # __call__ — fusion groups that mix acc and non-acc nodes are removed.
        self.acc_nodes = acc_nodes

    @dataclass
    class FusionGroup:
        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes that in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to fusion group.

            The node is queued for processing, removed from the group's input
            set (it is now internal), and its own callable inputs that are not
            already in the group become new group inputs.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ):
        """
        Start from inputs and going reverse topological order. If any upstream node
        is in the fusion group, add all the nodes in this path to fusion group.

        Returns True if a path from ``inputs`` back into the fusion group was
        found (the nodes along that path are added to the group), else False.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Build fusion groups and return a mapping from each fused node to the
        set of nodes in its fusion group (all members map to the same set).
        Groups that would pull in non-acc nodes are dropped, and their nodes
        are removed from ``self.acc_nodes``.
        """
        result: Dict[torch.fx.Node, NodeSet] = {}
        # Iterate over a snapshot since self.acc_nodes may shrink below.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            # Nodes carrying tensor_meta produce tensor outputs; fusion groups
            # are seeded only from nodes passing non-tensor data.
            if "tensor_meta" in node.meta:
                continue
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            # Fixed-point expansion: keep processing until no node in the
            # group has unexamined neighbors.
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    # New upstream member may precede the current top node.
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            # A fusion group must be all-acc or it cannot be fused; demote its
            # nodes from acc status rather than splitting the group.
            if not (set(fusion_group.nodes) <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
229
+
230
+
231
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
    order of its input GraphModule, so that this order is restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module in-place sorted

    Raises:
        RuntimeError: if the graph contains a cycle.
    """
    # Kahn's algorithm: track how many unfulfilled dependencies each node has.
    indeg = {node: 0 for node in gm.graph.nodes}
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1

    # Seed the worklist with all dependency-free nodes.
    ready: collections.deque = collections.deque(
        node for node in gm.graph.nodes if indeg[node] == 0
    )

    sorted_graph = torch.fx.Graph()
    # Maps original nodes to their copies in the sorted graph.
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    while ready:
        cur = ready.popleft()
        env[cur] = sorted_graph.node_copy(cur, lambda x: env[x])
        # Releasing `cur` may make its users ready.
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                ready.append(user)

    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied.)
    if len(sorted_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
    sorted_graph._codegen = gm.graph._codegen
    gm.graph = sorted_graph
    return gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functionality for Python <-> C++ frontend inter-op."""
2
+
3
+ from torch import nn
4
+
5
+
6
class OrderedDictWrapper:
    """A live, read-only view over an OrderedDict attribute of a C++ module.

    Rather than caching the dict once, every access re-fetches it from the
    bound C++ module, so mutations made on the C++ side are always observed.
    This matters because ``torch.nn.Module`` reads ``_parameters`` et al.
    directly through ``self.__dict__``, where a plain property would not work.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        # Re-evaluated on every access to stay in sync with the C++ side.
        return getattr(self.cpp_module, self.attr)

    # Dunder lookups bypass ``__getattr__``-style delegation, so each protocol
    # method below forwards to the underlying dict explicitly.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
47
+
48
+
49
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor.
        self.cpp_module = cpp_module
        super().__init__()
        # Replace the registries nn.Module just created with live views onto
        # the C++ module's dicts, so later C++-side changes are visible here.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        # Mirrors nn.Module._apply but mutates tensor storage in place via
        # ``.data`` (see comment below); ``recurse`` is accepted for signature
        # compatibility and not used here.
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        # Delegate to the C++ module so Python and C++ never disagree.
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Setting the flag routes through the C++ ``train()`` method.
        self.cpp_module.train(mode)

    def __repr__(self):
        return self.cpp_module.__repr__()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc ADDED
Binary file (713 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Intrinsic QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ __all__ = [
12
+ 'LinearReLU',
13
+ ]
14
+
15
+ from torch.ao.nn.intrinsic.qat import LinearReLU
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (435 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .linear_relu import LinearReLU
2
+
3
+ __all__ = [
4
+ 'LinearReLU',
5
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc ADDED
Binary file (366 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from torch.ao.nn.intrinsic.quantized import BNReLU2d
2
+ from torch.ao.nn.intrinsic.quantized import BNReLU3d
3
+
4
+ __all__ = [
5
+ 'BNReLU2d',
6
+ 'BNReLU3d',
7
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from torch.ao.nn.intrinsic.quantized import LinearReLU
2
+
3
+ __all__ = [
4
+ 'LinearReLU',
5
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc ADDED
Binary file (73.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc ADDED
Binary file (2.56 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc ADDED
Binary file (5.04 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc ADDED
Binary file (15.6 kB). View file