koichi12 commited on
Commit
ec4fbbc
·
verified ·
1 Parent(s): 9aa23eb

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py +406 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py +52 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py +24 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py +24 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py +44 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py +23 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py +39 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py +29 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py +27 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py +30 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py +21 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py +21 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py +20 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py +22 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py +21 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py +27 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py +25 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py +20 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py +23 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py +29 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py +22 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py +23 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py +17 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py +41 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py +26 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +141 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py +401 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py +114 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import dataclasses
3
+ import functools
4
+ import io
5
+ import json
6
+ import os
7
+ import re
8
+ import sys
9
+ import types
10
+ import warnings
11
+ import weakref
12
+ import zipfile
13
+ from collections import OrderedDict
14
+ from contextlib import contextmanager
15
+
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+ from unittest.mock import patch
18
+
19
+ import sympy
20
+
21
+ import torch
22
+ import torch._dynamo
23
+ import torch.fx
24
+ import torch.utils._pytree as pytree
25
+
26
+ from torch._decomp import core_aten_decompositions, get_decompositions
27
+ from torch._dispatch.python import enable_python_dispatcher
28
+ from torch._dynamo.exc import UserError, UserErrorType
29
+ from torch._dynamo.source import ConstantSource
30
+ from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
31
+ from torch._functorch.aot_autograd import aot_export_module, GraphSignature
32
+ from torch._functorch.eager_transforms import functionalize
33
+ from torch._guards import detect_fake_mode
34
+ from torch._inductor import config
35
+ from torch._ops import OpOverload
36
+ from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
37
+ from torch._subclasses.functional_tensor import FunctionalTensor
38
+ from torch._utils_internal import log_export_usage
39
+ from torch.export._tree_utils import reorder_kwargs
40
+ from torch.export._unlift import _create_stateful_graph_module
41
+ from torch.export.dynamic_shapes import (
42
+ _process_constraints,
43
+ _process_dynamic_shapes,
44
+ Constraint,
45
+ dims,
46
+ dynamic_dim,
47
+ )
48
+ from torch.export.exported_program import (
49
+ _disable_prexisiting_fake_mode,
50
+ ExportedProgram,
51
+ ModuleCallEntry,
52
+ ModuleCallSignature,
53
+ )
54
+ from torch.export.graph_signature import (
55
+ _sig_to_specs,
56
+ ArgumentSpec,
57
+ ConstantArgument,
58
+ ExportGraphSignature,
59
+ InputKind,
60
+ InputSpec,
61
+ OutputKind,
62
+ OutputSpec,
63
+ SymIntArgument,
64
+ TensorArgument,
65
+ )
66
+ from torch.fx import traceback as fx_traceback
67
+ from torch.fx._compatibility import compatibility
68
+ from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
69
+ from torch.fx.experimental.symbolic_shapes import (
70
+ ConstraintViolationError,
71
+ GuardOnDataDependentSymNode,
72
+ ShapeEnv,
73
+ StrictMinMaxConstraint,
74
+ )
75
+ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
76
+ from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges
77
+
78
+ from .passes.add_runtime_assertions_for_constraints_pass import (
79
+ _AddRuntimeAssertionsForInlineConstraintsPass,
80
+ )
81
+ from .wrappers import _wrap_submodules
82
+
83
+
84
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """
    # NOTE(review): presumably mirrors torch._dynamo.config.allow_rnn, i.e.
    # permits tracing through RNN modules during export — confirm against
    # torch._dynamo.config before relying on it.
    allow_rnn: bool = True
90
+
91
+
92
@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
    """
    A helper function that is intended to trace a module before any pre-autograd
    decomposition is run. The produced module will be "non-functional" and
    composed of aten operators. Later this API will be deleted in favor of more general
    torch.export API.

    Args:
        f: nn.Module to be traced

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

    Returns:
        An nn.Module containing the traced method.

    Raises:
        AssertionError: if ``f`` is not an ``nn.Module``.
    """
    # Imported lazily to avoid a circular import with torch.export._trace.
    from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
    from torch.export.dynamic_shapes import _process_dynamic_shapes

    log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}

    # Translate the user-facing dynamic_shapes spec into Dynamo constraints.
    constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)

    # Do not decompose dropout for exported models, because in eval mode the dropout
    # op disappears from the graph, which makes it difficult to switch to train mode.
    # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
    decomp_table = {
        op: op.decompose
        for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
        if op != torch.ops.aten.dropout.default
    }
    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
        # torch._dynamo.export returns a tuple whose first element is the
        # traced GraphModule; pre_dispatch=True keeps pre-autograd aten ops.
        m = torch._dynamo.export(
            f,
            constraints=constraints,
            assume_static_by_default=True,
            tracing_mode="symbolic",
            decomposition_table=decomp_table,
            pre_dispatch=True,
            aten_graph=True,
            _log_export_usage=False,
        )(
            *args,
            **kwargs,
        )[0]

    _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)

    # Keep only the data-dependent symbols (named like "i0"/"f0") that carry
    # inline range constraints.
    m.meta["inline_constraints"] = {
        k: v
        for k, v in fake_mode.shape_env.var_to_range.items()
        if re.match(r"^[if]\d+$", str(k))
    }

    # NOTE(review): this check is always true — `f` was asserted to be an
    # nn.Module above; the branch could be unconditional.
    if isinstance(f, torch.nn.Module):
        from torch.export._trace import _restore_state_dict
        _restore_state_dict(f, m)

    flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
    range_constraints = _process_constraints(fake_mode, m, 0, flat_args)

    module = _create_stateful_graph_module(
        m,
        range_constraints=range_constraints,
    )

    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    # Block mode switching on the exported module: dropout (kept above) and
    # similar ops would otherwise silently behave as if still in trace-time mode.
    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
    return module
209
+
210
+
211
def save(
    ep: ExportedProgram,
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """Serialize ``ep`` and write it to ``f`` as a zip archive.

    The archive contains ``serialized_exported_program.json``,
    ``serialized_state_dict.pt``, ``serialized_constants.pt``, a ``version``
    entry (dotted SCHEMA_VERSION), and one ``extra_files/<name>`` entry per
    item of ``extra_files`` (values are UTF-8 encoded strings).

    Args:
        ep: the ExportedProgram to save.
        f: destination path or binary file-like object.
        extra_files: optional mapping of archive-relative name -> string content.
        opset_version: optional opset-version map forwarded to ``serialize``.

    Raises:
        TypeError: if ``ep`` is not an ExportedProgram.
    """
    if not isinstance(ep, ExportedProgram):
        raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")

    from .serde.serialize import serialize, SerializedArtifact
    from .serde.schema import SCHEMA_VERSION
    artifact: SerializedArtifact = serialize(ep, opset_version)

    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    with zipfile.ZipFile(f, 'w') as zipf:
        # Save every field of the SerializedArtifact to a file in the archive.
        assert isinstance(artifact.exported_program, bytes)
        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
        zipf.writestr("serialized_constants.pt", artifact.constants)

        zipf.writestr('version', ".".join(map(str, SCHEMA_VERSION)))

        # Add extra files if provided
        if extra_files:
            for extra_file_name, content in extra_files.items():
                encoded_content = content.encode('utf-8')
                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
242
+
243
+
244
def load(
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """Load an ExportedProgram previously written by :func:`save`.

    Args:
        f: source path or binary file-like object containing the zip archive.
        extra_files: optional dict that is populated in place with the decoded
            contents of any ``extra_files/<name>`` archive entries.
        expected_opset_version: optional opset map forwarded to ``deserialize``.

    Returns:
        The deserialized ExportedProgram.

    Raises:
        RuntimeError: if the archive's major schema version does not match.
    """
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    extra_files = extra_files or {}

    with zipfile.ZipFile(f, 'r') as zipf:
        # Check the version: only the major component must match exactly.
        version = zipf.read('version').decode().split('.')
        from .serde.schema import SCHEMA_VERSION

        assert len(version) == len(SCHEMA_VERSION)
        if version[0] != str(SCHEMA_VERSION[0]):
            raise RuntimeError(
                f"Serialized version {version} does not match our current "
                f"schema version {SCHEMA_VERSION}."
            )

        from .serde.serialize import deserialize, SerializedArtifact

        # Load serialized_ep and serialized_state_dict from the zip file.
        # The ".json" state-dict/constants names are a deprecated older layout.

        serialized_exported_program: Optional[bytes] = None
        serialized_state_dict: Optional[bytes] = None
        serialized_constants: Optional[bytes] = None

        for file_info in zipf.infolist():
            file_content = zipf.read(file_info.filename)

            if file_info.filename == "serialized_exported_program.json":
                serialized_exported_program = file_content
            elif file_info.filename == "serialized_state_dict.json":
                warnings.warn("This version of file is deprecated")
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.json":
                warnings.warn("This version of file is deprecated")
                serialized_constants = file_content
            elif file_info.filename == "serialized_state_dict.pt":
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.pt":
                serialized_constants = file_content
            elif file_info.filename.startswith("extra_files"):
                filename = file_info.filename.split("/", 1)[1]
                extra_files[filename] = file_content.decode('utf-8')

        assert serialized_exported_program is not None
        assert serialized_state_dict is not None
        assert serialized_constants is not None
        artifact: SerializedArtifact = SerializedArtifact(
            serialized_exported_program,
            serialized_state_dict,
            serialized_constants,
        )

        # Deserialize ExportedProgram
        ep = deserialize(artifact, expected_opset_version)

        return ep
307
+
308
+
309
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
) -> str:
    """
    Note: this function is not stable yet

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable cpp code from the program, and returns
    the path to the generated shared library

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        options: A dictionary of options to control inductor

        remove_runtime_assertions: NOTE(review) — accepted but never read in this
            body; presumably consumed elsewhere or vestigial. Confirm before use.

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

    Returns:
        Path to the generated shared library
    """
    from torch.export._trace import _export_to_torch_ir
    # NOTE(review): select_decomp_table is imported but not used in this body.
    from torch._inductor.decomposition import select_decomp_table

    constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)

    if config.is_predispatch:
        gm = torch.export._trace._export(f, args, kwargs, constraints, pre_dispatch=True).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            constraints,
            disable_constraint_solver=disable_constraint_solver,
            # Disabling this flag, because instead we can rely on the mapping
            # dynamo_flat_name_to_original_fqn which is coming from Dynamo.
            restore_fqn=False,
        )
    # Inductor's AOT entry point takes the flattened leaf inputs.
    flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))

    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options)  # type: ignore[arg-type]

    return so_path
380
+
381
def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable

    Args:
        so_path: Path to the shared library

        device: ``"cpu"``, ``"cuda"``, or ``"cuda:<index>"``; any other value
            raises ``RuntimeError``.

    Returns:
        A callable
    """
    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):
        # The compiled container records the pytree specs of its inputs and
        # outputs; unflatten/flatten around the C++ runner call to present the
        # original calling convention.
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import importlib
3
+ from os.path import basename, dirname, isfile, join
4
+
5
+ import torch
6
+ from torch._export.db.case import (
7
+ _EXAMPLE_CASES,
8
+ _EXAMPLE_CONFLICT_CASES,
9
+ _EXAMPLE_REWRITE_CASES,
10
+ SupportLevel,
11
+ )
12
+
13
+
14
# Discover every example module (*.py) sitting next to this __init__.py.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
]

# Import all modules in the current directory. Importing each example module
# runs its module-level code (e.g. the @export_case decorators used by the
# example files).
from . import *  # noqa: F403
21
+
22
+
23
def all_examples():
    """Return the mapping of registered export example cases (_EXAMPLE_CASES)."""
    return _EXAMPLE_CASES
25
+
26
+
27
# Fail loudly at import time if two example cases were registered under the
# same name.
if len(_EXAMPLE_CONFLICT_CASES) > 0:

    def get_name(case):
        # A case's model may be an nn.Module instance or a class; normalize
        # to the class name either way.
        model = case.model
        if isinstance(model, torch.nn.Module):
            model = type(model)
        return model.__name__

    msg = "Error on conflict export case name.\n"
    for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
        msg += f"Case name {case_name} is associated with multiple cases:\n  "
        msg += f"[{','.join(map(get_name, cases))}]\n"

    raise RuntimeError(msg)
41
+
42
+
43
def filter_examples_by_support_level(support_level: SupportLevel):
    """Return the example cases whose ``support_level`` equals *support_level*.

    The result preserves the key/value pairs of :func:`all_examples`, keeping
    only the matching entries.
    """
    matching = {}
    for name, case in all_examples().items():
        if case.support_level == support_level:
            matching[name] = case
    return matching
49
+
50
+
51
def get_rewrite_cases(case):
    """Return the rewrite cases registered for *case*'s name (empty list if none)."""
    registry = _EXAMPLE_REWRITE_CASES
    # .get keeps lookup side-effect free (no insertion, unlike indexing a
    # defaultdict would be).
    return registry.get(case.name, [])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.96 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc ADDED
Binary file (1.91 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc ADDED
Binary file (1.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc ADDED
Binary file (3.06 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc ADDED
Binary file (1.81 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc ADDED
Binary file (1.96 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc ADDED
Binary file (1.56 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc ADDED
Binary file (1.53 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc ADDED
Binary file (1.64 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc ADDED
Binary file (1.78 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc ADDED
Binary file (1.73 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc ADDED
Binary file (1.81 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc ADDED
Binary file (1.24 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc ADDED
Binary file (1.28 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc ADDED
Binary file (2.86 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc ADDED
Binary file (1.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch._dynamo as torchdynamo
3
+
4
+ from torch._export.db.case import export_case
5
+
6
+
7
@export_case(
    example_inputs=(torch.ones(3, 2), torch.tensor(4)),
    tags={"torch.escape-hatch"},
)
class AssumeConstantResult(torch.nn.Module):
    """
    Applying the `assume_constant_result` decorator bakes the result of
    otherwise non-traceable code into the graph as a constant.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # .item() produces a data-dependent Python int; the decorator tells
        # Dynamo to treat the returned value as a constant.
        return y.int().item()

    def forward(self, x, y):
        # The slice bound comes from the "constant" produced above.
        return x[: self.get_item(y)]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
@export_case(
    example_inputs=(torch.ones(3, 4),),
)
class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.
    """

    @classmethod
    def method(cls, x):
        return x + 1

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        # All three spellings (self.method, self.__class__.method,
        # type(self).method) resolve to the same classmethod.
        return self.method(x) * self.__class__.method(x) * type(self).method(x)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import cond
5
+
6
+
7
@export_case(
    example_inputs=(torch.ones(3),),
    tags={
        "torch.cond",
        "torch.dynamic-shape",
    },
)
class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch function can be free function, nested function, lambda, class methods
    - branch function can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates using nested function in cond().

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Each branch delegates to a nested inner function; both take the same
        # argument and return a tensor of the same shape/dtype as required.
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import cond
5
+
6
+
7
+ @export_case(
8
+ example_inputs=(torch.tensor(True), torch.ones(3, 2)),
9
+ tags={"torch.cond", "python.closure"},
10
+ )
11
+ class CondClosedOverVariable(torch.nn.Module):
12
+ """
13
+ torch.cond() supports branches closed over arbitrary variables.
14
+ """
15
+
16
+ def forward(self, pred, x):
17
+ def true_fn(val):
18
+ return x * 2
19
+
20
+ def false_fn(val):
21
+ return x - 2
22
+
23
+ return cond(pred, true_fn, false_fn, [x + 1])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from torch.export import Dim
5
+ from functorch.experimental.control_flow import cond
6
+
7
# Module-level example inputs for the case below; dim 0 of `x` is declared
# dynamic via `dim0_x` in the dynamic_shapes spec.
x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

@export_case(
    example_inputs=(x, y),
    tags={
        "torch.cond",
        "torch.dynamic-shape",
    },
    extra_inputs=(torch.randn(2, 2), torch.ones(2)),
    dynamic_shapes={"x": {0: dim0_x}, "y": None},
)
class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - match arguments of `true_fn` and `false_fn`

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        # Operands [x, y] match the parameter lists of both branches.
        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import cond
5
+
6
+
7
+ @export_case(
8
+ example_inputs=(torch.ones(6, 4, 3),),
9
+ tags={
10
+ "torch.cond",
11
+ "torch.dynamic-shape",
12
+ },
13
+ )
14
+ class CondPredicate(torch.nn.Module):
15
+ """
16
+ The conditional statement (aka predicate) passed to cond() must be one of the following:
17
+ - torch.Tensor with a single element
18
+ - boolean expression
19
+
20
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
21
+ """
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def forward(self, x):
27
+ pred = x.dim() > 2 and x.shape[2] > 10
28
+
29
+ return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.tensor(4),),
8
+ tags={
9
+ "torch.dynamic-value",
10
+ "torch.escape-hatch",
11
+ },
12
+ )
13
+ class ConstrainAsSizeExample(torch.nn.Module):
14
+ """
15
+ If the value is not known at tracing time, you can provide hint so that we
16
+ can trace further. Please look at constrain_as_value and constrain_as_size APIs
17
+ constrain_as_size is used for values that NEED to be used for constructing
18
+ tensor.
19
+ """
20
+
21
+ def __init__(self):
22
+ super().__init__()
23
+
24
+ def forward(self, x):
25
+ a = x.item()
26
+ torch._constrain_as_size(a, min=0, max=5)
27
+ return torch.ones((a, 5))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.tensor(4), torch.randn(5, 5)),
8
+ tags={
9
+ "torch.dynamic-value",
10
+ "torch.escape-hatch",
11
+ },
12
+ )
13
+ class ConstrainAsValueExample(torch.nn.Module):
14
+ """
15
+ If the value is not known at tracing time, you can provide hint so that we
16
+ can trace further. Please look at constrain_as_value and constrain_as_size APIs.
17
+ constrain_as_value is used for values that don't need to be used for constructing
18
+ tensor.
19
+ """
20
+
21
+ def __init__(self):
22
+ super().__init__()
23
+
24
+ def forward(self, x, y):
25
+ a = x.item()
26
+ torch._constrain_as_value(a, min=0, max=5)
27
+
28
+ if a < 6:
29
+ return y.sin()
30
+ return y.cos()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2), torch.tensor(4)),
8
+ tags={"python.data-structure"},
9
+ )
10
+ class Dictionary(torch.nn.Module):
11
+ """
12
+ Dictionary structures are inlined and flattened along tracing.
13
+ """
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x, y):
18
+ elements = {}
19
+ elements["x2"] = x * x
20
+ y = y * elements["x2"]
21
+ return {"y": y}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2, 2),),
8
+ tags={"torch.dynamic-shape", "python.control-flow"},
9
+ )
10
+ class DynamicShapeIfGuard(torch.nn.Module):
11
+ """
12
+ `if` statement with backed dynamic shape predicate will be specialized into
13
+ one particular branch and generate a guard. However, export will fail if the
14
+ the dimension is marked as dynamic shape from higher level API.
15
+ """
16
+
17
+ def forward(self, x):
18
+ if x.shape[0] == 3:
19
+ return x.cos()
20
+
21
+ return x.sin()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"torch.dynamic-shape"},
9
+ )
10
+ class DynamicShapeSlicing(torch.nn.Module):
11
+ """
12
+ Slices with dynamic shape arguments should be captured into the graph
13
+ rather than being baked in.
14
+ """
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, x):
20
+ return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(10, 10),),
8
+ tags={"torch.dynamic-shape"},
9
+ )
10
+ class DynamicShapeView(torch.nn.Module):
11
+ """
12
+ Dynamic shapes should be propagated to view arguments instead of being
13
+ baked into the exported graph.
14
+ """
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, x):
20
+ new_x_shape = x.size()[:-1] + (2, 5)
21
+ x = x.view(*new_x_shape)
22
+ return x.permute(0, 2, 1)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"torch.dynamic-shape", "python.data-structure", "python.assert"},
9
+ )
10
+ class ListContains(torch.nn.Module):
11
+ """
12
+ List containment relation can be checked on a dynamic shape or constants.
13
+ """
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x):
18
+ assert x.size(-1) in [6, 2]
19
+ assert x.size(0) not in [4, 5, 6]
20
+ assert "monkey" not in ["cow", "pig"]
21
+ return x + x
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+
5
+ from torch._export.db.case import export_case
6
+
7
+
8
+ @export_case(
9
+ example_inputs=([torch.ones(3, 2), torch.tensor(4), torch.tensor(5)],),
10
+ tags={"python.control-flow", "python.data-structure"},
11
+ )
12
+ class ListUnpack(torch.nn.Module):
13
+ """
14
+ Lists are treated as static construct, therefore unpacking should be
15
+ erased after tracing.
16
+ """
17
+
18
+ def __init__(self):
19
+ super().__init__()
20
+
21
+ def forward(self, args: List[torch.Tensor]):
22
+ """
23
+ Lists are treated as static construct, therefore unpacking should be
24
+ erased after tracing.
25
+ """
26
+ x, *y = args
27
+ return x + y[0]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"python.object-model"},
9
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
10
+ )
11
+ class ModelAttrMutation(torch.nn.Module):
12
+ """
13
+ Attribute mutation is not supported.
14
+ """
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)]
19
+
20
+ def recreate_list(self):
21
+ return [torch.zeros(3, 2), torch.zeros(3, 2)]
22
+
23
+ def forward(self, x):
24
+ self.attr_list = self.recreate_list()
25
+ return x.sum() + self.attr_list[0].sum()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+ from torch.utils import _pytree as pytree
5
+
6
+
7
+ @export_case(
8
+ example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
9
+ support_level=SupportLevel.SUPPORTED,
10
+ )
11
+ class PytreeFlatten(torch.nn.Module):
12
+ """
13
+ Pytree from PyTorch can be captured by TorchDynamo.
14
+ """
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, x):
19
+ y, spec = pytree.tree_flatten(x)
20
+ return y[0] + 1
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from torch.export import Dim
5
+
6
+ x = torch.ones(3, 2)
7
+ dim1_x = Dim("dim1_x")
8
+
9
+ @export_case(
10
+ example_inputs=(x,),
11
+ tags={"torch.dynamic-shape"},
12
+ dynamic_shapes={"x": {1: dim1_x}},
13
+ )
14
+ class ScalarOutput(torch.nn.Module):
15
+ """
16
+ Returning scalar values from the graph is supported, in addition to Tensor
17
+ outputs. Symbolic shapes are captured and rank is specialized.
18
+ """
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, x):
23
+ return x.shape[1] + 1
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from torch._export.db.case import export_case
6
+
7
+
8
+ class Animal(Enum):
9
+ COW = "moo"
10
+
11
+
12
+ @export_case(
13
+ example_inputs=(torch.ones(3, 2),),
14
+ )
15
+ class SpecializedAttribute(torch.nn.Module):
16
+ """
17
+ Model attributes are specialized.
18
+ """
19
+
20
+ def __init__(self):
21
+ super().__init__()
22
+ self.a = "moo"
23
+ self.b = 4
24
+
25
+ def forward(self, x):
26
+ if self.a == Animal.COW.value:
27
+ return x * x + self.b
28
+ else:
29
+ raise ValueError("bad")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"python.control-flow"},
9
+ )
10
+ class StaticForLoop(torch.nn.Module):
11
+ """
12
+ A for loop with constant number of iterations should be unrolled in the exported graph.
13
+ """
14
+
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, x):
19
+ ret = []
20
+ for i in range(10): # constant
21
+ ret.append(i + x)
22
+ return ret
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2, 2),),
8
+ tags={"python.control-flow"},
9
+ )
10
+ class StaticIf(torch.nn.Module):
11
+ """
12
+ `if` statement with static predicate value should be traced through with the
13
+ taken branch.
14
+ """
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, x):
20
+ if len(x.shape) == 3:
21
+ return x + torch.ones(1, 1, 1)
22
+
23
+ return x
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.randn(3, 2), "attr"),
8
+ tags={"python.builtin"},
9
+ support_level=SupportLevel.SUPPORTED,
10
+ )
11
+ class TensorSetattr(torch.nn.Module):
12
+ """
13
+ setattr() call onto tensors is not supported.
14
+ """
15
+ def forward(self, x, attr):
16
+ setattr(x, attr, torch.randn(3, 2))
17
+ return x + 4
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel, export_rewrite_case
4
+
5
+
6
+ class A:
7
+ @classmethod
8
+ def func(cls, x):
9
+ return 1 + x
10
+
11
+
12
+ @export_case(
13
+ example_inputs=(torch.ones(3, 4),),
14
+ tags={"python.builtin"},
15
+ support_level=SupportLevel.SUPPORTED,
16
+ )
17
+ class TypeReflectionMethod(torch.nn.Module):
18
+ """
19
+ type() calls on custom objects followed by attribute accesses are not allowed
20
+ due to its overly dynamic nature.
21
+ """
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def forward(self, x):
27
+ a = A()
28
+ return type(a).func(x)
29
+
30
+
31
+ @export_rewrite_case(parent=TypeReflectionMethod)
32
+ class TypeReflectionMethodRewrite(torch.nn.Module):
33
+ """
34
+ Custom object class methods will be inlined.
35
+ """
36
+
37
+ def __init__(self):
38
+ super().__init__()
39
+
40
+ def forward(self, x):
41
+ return A.func(x)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
3
+
4
+
5
+ class _RemoveRuntimeAssertionsPass(PassBase):
6
+ """
7
+ Remove runtime assertions inserted by the
8
+ _AddRuntimeAssertionsForInlineConstraintsPass.
9
+ """
10
+
11
+ def call(self, graph_module) -> PassResult:
12
+ modified = False
13
+ for module in graph_module.modules():
14
+ if not isinstance(module, torch.fx.GraphModule):
15
+ continue
16
+ for node in module.graph.nodes:
17
+ if node.target == torch.ops.aten._assert_async.msg:
18
+ assert_async_node = node
19
+ if len(assert_async_node.users) > 0:
20
+ continue
21
+ module.graph.erase_node(assert_async_node)
22
+ # the upstream scalar_tensor <- {le, ge} <- sym_size
23
+ # linear chain of nodes of nodes is removed by the
24
+ # downstream dead code elimination
25
+ modified = True
26
+ return PassResult(graph_module, modified)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
3
+
4
+ from ..utils import (
5
+ node_inline_,
6
+ node_replace_,
7
+ nodes_filter,
8
+ nodes_first,
9
+ nodes_map,
10
+ sequential_split,
11
+ )
12
+
13
+
14
+ def _is_set_grad_enabled_node(node: torch.fx.Node):
15
+ return (
16
+ node
17
+ and node.op == "call_function"
18
+ and node.target == torch._C._set_grad_enabled
19
+ )
20
+
21
+
22
+ def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
23
+ if node.op == "call_module":
24
+ assert isinstance(node.target, str)
25
+ subgm = getattr(node.graph.owning_module, node.target)
26
+ first_non_ph = nodes_first(
27
+ subgm.graph.nodes, lambda node: node.op != "placeholder"
28
+ )
29
+ if (
30
+ first_non_ph
31
+ and first_non_ph.op == "call_function"
32
+ and first_non_ph.target == torch._C._set_grad_enabled
33
+ ):
34
+ return (
35
+ first_non_ph.args[0] != torch.is_grad_enabled()
36
+ if omit_if_same_with_ambient
37
+ else True
38
+ )
39
+ return False
40
+
41
+
42
+ def _replace_with_hop(node: torch.fx.Node):
43
+ assert node.op == "call_module"
44
+ graph: torch.fx.Graph = node.graph
45
+ gm: torch.fx.GraphModule = graph.owning_module
46
+ assert isinstance(node.target, str)
47
+ sub_gm = getattr(gm, node.target)
48
+ sub_graph = sub_gm.graph
49
+ set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
50
+ if len(set_grad_nodes) > 0:
51
+ assert len(set_grad_nodes) == 1
52
+ set_grad_node = set_grad_nodes[0]
53
+ enable_grad_val = set_grad_node.args[0]
54
+ with graph.inserting_before(node):
55
+ get_attr_node = graph.get_attr(node.target)
56
+ output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
57
+ if output_node is not None:
58
+ assert len(output_node.args) == 1
59
+ output_args = output_node.args[0]
60
+ if isinstance(output_args, (tuple, list)):
61
+ call_func_node = graph.call_function(
62
+ wrap_with_set_grad_enabled,
63
+ (enable_grad_val, get_attr_node, *node.args),
64
+ {},
65
+ )
66
+ # Create the metadata
67
+ call_func_node.meta["val"] = tuple(
68
+ arg.meta["val"] for arg in output_args
69
+ )
70
+ node_replace_(node, call_func_node, delete_old=True)
71
+
72
+ # Rename the name of getitem nodes to the actual name of its contents
73
+ # for passing verifier and better readability, also propagate metadata
74
+ for get_item_node in call_func_node.users.keys():
75
+ idx: int = get_item_node.args[1]
76
+ output_node = output_args[idx]
77
+ get_item_node._rename(output_node.name)
78
+ get_item_node.meta = output_node.meta
79
+ pass
80
+
81
+ elif isinstance(output_args, torch.fx.Node):
82
+ call_func_node = graph.create_node(
83
+ "call_function",
84
+ wrap_with_set_grad_enabled,
85
+ (enable_grad_val, get_attr_node, *node.args),
86
+ {},
87
+ output_args.name,
88
+ )
89
+ call_func_node.meta = output_args.meta
90
+ node_replace_(node, call_func_node, delete_old=True)
91
+ else:
92
+ raise NotImplementedError(
93
+ f"repalce_set_grad_with_hop_pass doesnt' support output type {type(output_args)}"
94
+ )
95
+ else:
96
+ raise NotImplementedError(
97
+ "Cannot replace a call_module with a hop if it has no output. This module will gets DCEed."
98
+ )
99
+ sub_graph.erase_node(set_grad_node)
100
+
101
+
102
+ def _remove_set_grad_and_inline(node: torch.fx.Node):
103
+ assert node.op == "call_module"
104
+ graph: torch.fx.Graph = node.graph
105
+ gm: torch.fx.GraphModule = graph.owning_module
106
+ assert isinstance(node.target, str)
107
+ sub_gm = getattr(gm, node.target)
108
+ sub_graph = sub_gm.graph
109
+ nodes_map(
110
+ sub_graph.nodes,
111
+ lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
112
+ )
113
+ node_inline_(node)
114
+
115
+
116
+ def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
117
+ # If there is no set_grad_enabled node, return the original graph module
118
+ need_replacing = False
119
+ for node in gm.graph.nodes:
120
+ if _is_set_grad_enabled_node(node):
121
+ need_replacing = True
122
+
123
+ if not need_replacing:
124
+ return gm
125
+
126
+ new_gm = sequential_split(gm, _is_set_grad_enabled_node)
127
+
128
+ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
129
+ if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
130
+ _replace_with_hop(node)
131
+ else:
132
+ _remove_set_grad_and_inline(node)
133
+
134
+ nodes_map(
135
+ list(new_gm.graph.nodes),
136
+ lambda node: _maybe_inline_or_replace_with_hop(node)
137
+ if node.op == "call_module"
138
+ else node,
139
+ )
140
+ new_gm.graph.lint()
141
+ return new_gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import math
3
+ import operator
4
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+ from torch._subclasses.fake_tensor import FakeTensor
8
+
9
+ from torch.export import ExportedProgram
10
+ from torch.utils._pytree import (
11
+ _register_pytree_node,
12
+ Context,
13
+ FlattenFunc,
14
+ FromDumpableContextFn,
15
+ KeyPath,
16
+ keystr,
17
+ MappingKey,
18
+ SequenceKey,
19
+ ToDumpableContextFn,
20
+ UnflattenFunc,
21
+ )
22
+
23
+
24
+ def _check_input_constraints_for_graph(
25
+ input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
26
+ ):
27
+ def get_keystr(key_path: KeyPath) -> str:
28
+ """For a given index into the flat_args, return a human readable string
29
+ describing how to access it, e.g. "*args["foo"][0].bar"
30
+ """
31
+ # Prefix the keypath with "*args" or "**kwargs" to make it clearer where
32
+ # the arguments come from. Ultimately we ought to serialize the
33
+ # original arg names for the best error message here.
34
+ args_kwargs_key_path = key_path[0]
35
+ assert isinstance(args_kwargs_key_path, SequenceKey)
36
+ if args_kwargs_key_path.idx == 0:
37
+ return f"*args{keystr(key_path[1:])}"
38
+ else:
39
+ kwarg_key = key_path[1]
40
+ assert isinstance(kwarg_key, MappingKey)
41
+ name = str(kwarg_key)[1:-1] # get rid of the enclosed []
42
+ return f"{name}{keystr(key_path[2:])}"
43
+
44
+ import sympy
45
+
46
+ from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
47
+ _convert_range_to_int,
48
+ )
49
+ from torch.utils._sympy.solve import try_solve
50
+
51
+ if len(flat_args_with_path) != len(input_placeholders):
52
+ raise RuntimeError(
53
+ "Unexpected number of inputs "
54
+ f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
55
+ )
56
+ # NOTE: export already guarantees that the same symbol is used in metadata
57
+ # for all InputDims related by equality constraints, so we can just unify
58
+ # symbols with given input dimension values to check equality constraints.
59
+ unification_map: "Dict[sympy.Symbol, Any]" = {}
60
+ for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
61
+ node_val = node.meta.get("val")
62
+ if isinstance(node_val, FakeTensor):
63
+ if not isinstance(arg, torch.Tensor):
64
+ raise RuntimeError(
65
+ f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
66
+ )
67
+
68
+ if len(node_val.shape) != len(arg.shape):
69
+ raise RuntimeError(
70
+ f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
71
+ f"(expected {node_val.shape}, got {arg.shape})"
72
+ )
73
+
74
+ for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
75
+ # TODO(avik): Assert the following property in the IR verifier:
76
+ # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
77
+ if (
78
+ isinstance(node_dim, torch.SymInt)
79
+ and len(node_dim.node.expr.free_symbols) == 1
80
+ ):
81
+ symbol = next(iter(node_dim.node.expr.free_symbols))
82
+ if symbol in unification_map:
83
+ existing_dim = node_dim.node.expr.subs(unification_map)
84
+ if arg_dim != existing_dim:
85
+ raise RuntimeError(
86
+ f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
87
+ f"{existing_dim}, but got {arg_dim}",
88
+ )
89
+ else:
90
+ if (
91
+ isinstance(arg_dim, torch.SymInt)
92
+ and not arg_dim.node.expr.is_number
93
+ ):
94
+ # This can happen when, say, arg is a fake tensor.
95
+ # We do not run checks on symbolic shapes of fake inputs as
96
+ # such checks can affect the shape env.
97
+ pass
98
+ else:
99
+ solution = try_solve(
100
+ sympy.Eq(node_dim.node.expr, arg_dim), symbol
101
+ )
102
+ if solution is None:
103
+ raise RuntimeError( # noqa: TRY200
104
+ f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
105
+ f"of the form {node_dim.node.expr}, where {symbol} is an integer"
106
+ )
107
+ else:
108
+ unification_map[symbol] = int(solution[1])
109
+
110
+ if node_dim.node.expr in range_constraints:
111
+ min_val, max_val = _convert_range_to_int(
112
+ range_constraints[node_dim.node.expr]
113
+ )
114
+ # NOTE: we allow dimensions to be 0/1 at runtime
115
+ if min_val > 2:
116
+ if arg_dim < min_val:
117
+ raise RuntimeError(
118
+ f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
119
+ f"{min_val}, but got {arg_dim}",
120
+ )
121
+ if max_val < math.inf:
122
+ if arg_dim > max_val:
123
+ raise RuntimeError(
124
+ f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
125
+ f"{max_val}, but got {arg_dim}",
126
+ )
127
+ else:
128
+ if arg_dim != node_dim:
129
+ raise RuntimeError(
130
+ f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
131
+ f"{node_dim}, but got {arg_dim}",
132
+ )
133
+ elif isinstance(node_val, (int, float, str)):
134
+ if type(arg) != type(node_val) or arg != node_val:
135
+ raise RuntimeError(
136
+ f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
137
+ )
138
+
139
+
140
+ def register_dataclass_as_pytree_node(
141
+ cls: Type[Any],
142
+ flatten_fn: Optional[FlattenFunc] = None,
143
+ unflatten_fn: Optional[UnflattenFunc] = None,
144
+ *,
145
+ serialized_type_name: Optional[str] = None,
146
+ to_dumpable_context: Optional[ToDumpableContextFn] = None,
147
+ from_dumpable_context: Optional[FromDumpableContextFn] = None,
148
+ return_none_fields: bool = False,
149
+ ) -> None:
150
+ assert dataclasses.is_dataclass(
151
+ cls
152
+ ), f"Only dataclasses can be registered with this function: {cls}"
153
+
154
+ def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
155
+ flattened = []
156
+ flat_names = []
157
+ none_names = []
158
+ for f in dataclasses.fields(obj):
159
+ name, val = f.name, getattr(obj, f.name)
160
+ if val is not None or return_none_fields:
161
+ flattened.append(val)
162
+ flat_names.append(name)
163
+ else:
164
+ none_names.append(name)
165
+ return flattened, [flat_names, none_names]
166
+
167
+ def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
168
+ flat_names, none_names = context
169
+ return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
170
+
171
+ flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
172
+ unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
173
+
174
+ if (to_dumpable_context is None) ^ (from_dumpable_context is None):
175
+ raise ValueError(
176
+ f"Both to_dumpable_context and from_dumpable_context for {cls} must "
177
+ "be None or registered."
178
+ )
179
+
180
+ _register_pytree_node(
181
+ cls,
182
+ flatten_fn,
183
+ unflatten_fn,
184
+ serialized_type_name=serialized_type_name,
185
+ to_dumpable_context=to_dumpable_context,
186
+ from_dumpable_context=from_dumpable_context,
187
+ )
188
+
189
+
190
+ def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
191
+ """
192
+ Checks if the given node is a parameter within the exported program
193
+ """
194
+
195
+ return node.name in program.graph_signature.inputs_to_parameters
196
+
197
+
198
+ def get_param(
199
+ program: ExportedProgram,
200
+ node: torch.fx.Node,
201
+ ) -> Optional[torch.nn.Parameter]:
202
+ """
203
+ Returns the parameter associated with the given node in the exported program.
204
+ Returns None if the node is not a parameter within the exported program
205
+ """
206
+
207
+ if is_param(program, node):
208
+ parameter_name = program.graph_signature.inputs_to_parameters[node.name]
209
+ return program.state_dict[parameter_name]
210
+
211
+ return None
212
+
213
+
214
+ def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
215
+ """
216
+ Checks if the given node is a buffer within the exported program
217
+ """
218
+
219
+ return node.name in program.graph_signature.inputs_to_buffers
220
+
221
+
222
+ def get_buffer(
223
+ program: ExportedProgram,
224
+ node: torch.fx.Node,
225
+ ) -> Optional[torch.Tensor]:
226
+ """
227
+ Returns the buffer associated with the given node in the exported program.
228
+ Returns None if the node is not a buffer within the exported program
229
+ """
230
+
231
+ if is_buffer(program, node):
232
+ buffer_name = program.graph_signature.inputs_to_buffers[node.name]
233
+ if buffer_name in program.graph_signature.non_persistent_buffers:
234
+ return program.constants[buffer_name]
235
+ else:
236
+ return program.state_dict[buffer_name]
237
+
238
+ return None
239
+
240
+
241
+ def is_lifted_tensor_constant(
242
+ program: ExportedProgram,
243
+ node: torch.fx.Node,
244
+ ) -> bool:
245
+ """
246
+ Checks if the given node is a lifted tensor constant within the exported program
247
+ """
248
+
249
+ return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
250
+
251
+
252
+ def get_lifted_tensor_constant(
253
+ program: ExportedProgram,
254
+ node: torch.fx.Node,
255
+ ) -> Optional[torch.Tensor]:
256
+ """
257
+ Returns the lifted tensor constant associated with the given node in the exported program.
258
+ Returns None if the node is not a lifted tensor constant within the exported program
259
+ """
260
+
261
+ if is_lifted_tensor_constant(program, node):
262
+ lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
263
+ node.name
264
+ ]
265
+ return program.constants[lifted_tensor_name]
266
+
267
+ return None
268
+
269
+
270
+ def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
271
+ """
272
+ Splits the graph module into multiple submodules based on the node_call_back.
273
+ The node_call_back should return True if the node is a delimiter. Delimiter will be
274
+ the first node in the next submodule.
275
+ """
276
+ from torch.fx.passes.split_module import split_module
277
+
278
+ split_map = {}
279
+ split_id = 0
280
+ for node in gm.graph.nodes:
281
+ if node_call_back(node):
282
+ split_id += 1
283
+ split_map[node] = split_id
284
+
285
+ new_gm = split_module(
286
+ gm,
287
+ gm,
288
+ lambda node: split_map[node],
289
+ keep_original_order=True,
290
+ keep_original_node_name=True,
291
+ )
292
+ # Keep the codegen from original graph module to preserve e.g. pytree info.
293
+ new_gm.graph._codegen = gm.graph._codegen
294
+ new_gm.recompile()
295
+ return new_gm
296
+
297
+
298
+ def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
299
+ """Returns the nodes that match the node_call_back as a list."""
300
+ return [node for node in nodes if node_call_back(node)]
301
+
302
+
303
+ def nodes_first(
304
+ nodes: List[torch.fx.Node], node_call_back=None
305
+ ) -> Optional[torch.fx.Node]:
306
+ """
307
+ Returns the first node that matches the node_call_back. If no node matches, returns None.
308
+ When node_call_back is None, returns the first node in the node list.
309
+ """
310
+ ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
311
+ if len(ret) > 0:
312
+ return ret[0]
313
+ return None
314
+
315
+
316
+ def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
317
+ """Returns the number of nodes that match the node_call_back."""
318
+ return len(nodes_filter(nodes, node_call_back))
319
+
320
+
321
+ def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
322
+ """
323
+ Sequentially visit the nodes list and invoke node_call_back on each element.
324
+ Returns the nodes list after the node_call_back is invoked on each element.
325
+ """
326
+ for node in nodes:
327
+ node_call_back(node)
328
+ return nodes
329
+
330
+
331
+ def node_replace_(
332
+ old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False
333
+ ) -> None:
334
+ """
335
+ Replace all uses of old_node with new_node.
336
+ """
337
+ old_node.replace_all_uses_with(new_node)
338
+ if delete_old:
339
+ old_node.users.clear()
340
+ old_node.graph.erase_node(old_node)
341
+
342
+
343
+ def node_inline_(call_mod_node: torch.fx.Node) -> None:
344
+ """
345
+ Inline the submodule of the given node into the parent module.
346
+ Note: we only support the case where submodule takes tensors inputs.
347
+ """
348
+ assert call_mod_node.op == "call_module"
349
+ gm = call_mod_node.graph.owning_module
350
+
351
+ assert isinstance(call_mod_node.target, str)
352
+ sub_gm = getattr(gm, call_mod_node.target)
353
+
354
+ phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
355
+ body = (
356
+ node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
357
+ )
358
+ output = [node for node in sub_gm.graph.nodes if node.op == "output"]
359
+
360
+ for ph, arg in zip(phs, call_mod_node.args):
361
+ assert isinstance(arg, torch.fx.Node)
362
+ node_replace_(ph, arg, delete_old=True)
363
+
364
+ with gm.graph.inserting_before(call_mod_node):
365
+ for node in body:
366
+ new_node = gm.graph.node_copy(node)
367
+ node_replace_(node, new_node, delete_old=True)
368
+
369
+ if len(output) > 0:
370
+ assert len(output) == 1 and len(output[0].args) == 1
371
+ new_output = output[0].args[0]
372
+
373
+ if isinstance(new_output, torch.fx.Node):
374
+ node_replace_(call_mod_node, new_output, delete_old=True)
375
+ elif isinstance(new_output, (list, tuple)):
376
+ # Inline the get_item calls for the output node.
377
+ get_item_users = nodes_filter(
378
+ list(call_mod_node.users.keys()),
379
+ lambda node: node.op == "call_function"
380
+ and node.target == operator.getitem,
381
+ )
382
+ # get_item_node.args[1] is the idx referring to new_output[idx]
383
+ nodes_map(
384
+ get_item_users,
385
+ lambda get_item_node: node_replace_(
386
+ get_item_node,
387
+ new_output[get_item_node.args[1]],
388
+ delete_old=True,
389
+ ),
390
+ )
391
+ call_mod_node.graph.erase_node(call_mod_node)
392
+ else:
393
+ raise NotImplementedError(
394
+ f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
395
+ )
396
+ else:
397
+ call_mod_node.graph.erase_node(call_mod_node)
398
+
399
+ gm.delete_all_unused_submodules()
400
+ gm.recompile()
401
+ return gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+ import torch._custom_ops
5
+ from torch._C import DispatchKey
6
+ from torch._higher_order_ops.strict_mode import strict_mode
7
+ from torch._higher_order_ops.utils import autograd_not_implemented
8
+ from torch._ops import HigherOrderOperator
9
+ from torch._subclasses.fake_tensor import FakeTensorMode
10
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
11
+ from torch.utils import _pytree as pytree
12
+
13
+
14
+ _export_tracepoint = HigherOrderOperator("_export_tracepoint")
15
+
16
+
17
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    """Record `_export_tracepoint` as a call_function node in the proxy graph.

    When tracing is disabled on the mode, simply redispatch without
    recording anything.
    """
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)
    tracer = mode.tracer
    proxied_args, proxied_kwargs = pytree.tree_map(
        tracer.unwrap_proxy, (args, kwargs)
    )
    node_proxy = tracer.create_proxy(
        "call_function", _export_tracepoint, proxied_args, proxied_kwargs
    )
    # Associate the concrete tensors with the freshly created proxy node.
    return track_tensor_tree(args, node_proxy, constant=None, tracer=tracer)
26
+
27
+
28
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    """Under FakeTensorMode the tracepoint is a pure identity on its inputs."""
    with mode:
        result = args
    return result
32
+
33
+
34
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    """Functionalization rule: unwrap functional tensors, redispatch, re-wrap."""
    plain_args = ctx.unwrap_tensors(args)
    plain_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        result = _export_tracepoint(*plain_args, **plain_kwargs)
        return ctx.wrap_tensors(result)
42
+
43
+
44
# Autograd is deliberately unsupported for this op. `deferred_error=True`
# means registration succeeds and the error is raised only if a backward
# pass actually reaches the tracepoint.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
47
+
48
+
49
@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    """Eager CPU kernel: the tracepoint is an identity — it returns its
    positional args unchanged (kwargs such as `kind`/`path` are metadata only)."""
    return args
52
+
53
+
54
+ def _wrap_submodule(mod, path, module_call_specs):
55
+ assert isinstance(mod, torch.nn.Module)
56
+ assert path != ""
57
+ submodule = mod
58
+ for name in path.split("."):
59
+ if not hasattr(submodule, name):
60
+ raise RuntimeError(f"Couldn't find submodule at path {path}")
61
+ submodule = getattr(submodule, name)
62
+
63
+ def update_module_call_signatures(path, in_spec, out_spec):
64
+ if path in module_call_specs:
65
+ assert module_call_specs[path]["in_spec"] == in_spec
66
+ assert module_call_specs[path]["out_spec"] == out_spec
67
+ module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
68
+
69
+ def check_flattened(flat_args):
70
+ for a in flat_args:
71
+ if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
72
+ raise AssertionError(
73
+ f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
74
+ )
75
+
76
+ def pre_hook(module, args, kwargs):
77
+ flat_args, in_spec = pytree.tree_flatten((args, kwargs))
78
+ check_flattened(flat_args)
79
+ flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
80
+ args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
81
+ return args, kwargs
82
+
83
+ def post_hook(module, args, kwargs, res):
84
+ _, in_spec = pytree.tree_flatten((args, kwargs))
85
+ flat_res, out_spec = pytree.tree_flatten(res)
86
+ check_flattened(flat_res)
87
+ flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
88
+ update_module_call_signatures(path, in_spec, out_spec)
89
+ return pytree.tree_unflatten(flat_res, out_spec)
90
+
91
+ pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
92
+ post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
93
+ return pre_handle, post_handle
94
+
95
+
96
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    """Context manager: hook every submodule path in `preserve_signature`,
    yield to the caller, and always remove the hooks afterwards — even if
    installation or the wrapped body raises partway through."""
    installed = []

    try:
        for path in preserve_signature:
            pre_handle, post_handle = _wrap_submodule(
                f, path, module_call_signatures
            )
            installed.append(pre_handle)
            installed.append(post_handle)
        yield
    finally:
        for h in installed:
            h.remove()
107
+
108
+
109
+ def _mark_strict_experimental(cls):
110
+ def call(self, *args):
111
+ return strict_mode(self, args)
112
+
113
+ cls.__call__ = call
114
+ return cls
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.18 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc ADDED
Binary file (1.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc ADDED
Binary file (1.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (358 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc ADDED
Binary file (54.8 kB). View file