koichi12 commited on
Commit
65e568a
·
verified ·
1 Parent(s): ec4fbbc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py +5 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py +188 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py +26 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py +46 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py +63 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py +26 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py +22 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py +19 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py +23 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py +24 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py +32 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py +27 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py +26 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py +19 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py +17 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py +18 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py +2 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py +32 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py +1 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc ADDED
Binary file (2.81 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (217 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc ADDED
Binary file (1.51 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc ADDED
Binary file (383 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import re
3
+ import string
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
7
+ from types import ModuleType
8
+
9
+ import torch
10
+
11
+ _TAGS: Dict[str, Dict[str, Any]] = {
12
+ "torch": {
13
+ "cond": {},
14
+ "dynamic-shape": {},
15
+ "escape-hatch": {},
16
+ "map": {},
17
+ "dynamic-value": {},
18
+ "operator": {},
19
+ "mutation": {},
20
+ },
21
+ "python": {
22
+ "assert": {},
23
+ "builtin": {},
24
+ "closure": {},
25
+ "context-manager": {},
26
+ "control-flow": {},
27
+ "data-structure": {},
28
+ "standard-library": {},
29
+ "object-model": {},
30
+ },
31
+ }
32
+
33
+
34
+ class SupportLevel(Enum):
35
+ """
36
+ Indicates at what stage the feature
37
+ used in the example is handled in export.
38
+ """
39
+
40
+ SUPPORTED = 1
41
+ NOT_SUPPORTED_YET = 0
42
+
43
+
44
+ class ExportArgs:
45
+ __slots__ = ("args", "kwargs")
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ self.args = args
49
+ self.kwargs = kwargs
50
+
51
+
52
+ InputsType = Union[Tuple[Any, ...], ExportArgs]
53
+
54
+
55
+ def check_inputs_type(x):
56
+ if not isinstance(x, (ExportArgs, tuple)):
57
+ raise ValueError(
58
+ f"Expecting inputs type to be either a tuple, or ExportArgs, got: {type(x)}"
59
+ )
60
+
61
+
62
+ def _validate_tag(tag: str):
63
+ parts = tag.split(".")
64
+ t = _TAGS
65
+ for part in parts:
66
+ assert set(part) <= set(
67
+ string.ascii_lowercase + "-"
68
+ ), f"Tag contains invalid characters: {part}"
69
+ if part in t:
70
+ t = t[part]
71
+ else:
72
+ raise ValueError(f"Tag {tag} is not found in registered tags.")
73
+
74
+
75
+ @dataclass(frozen=True)
76
+ class ExportCase:
77
+ example_inputs: InputsType
78
+ description: str # A description of the use case.
79
+ model: torch.nn.Module
80
+ name: str
81
+ extra_inputs: Optional[InputsType] = None # For testing graph generalization.
82
+ # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
83
+ tags: Set[str] = field(default_factory=set)
84
+ support_level: SupportLevel = SupportLevel.SUPPORTED
85
+ dynamic_shapes: Optional[Dict[str, Any]] = None
86
+
87
+ def __post_init__(self):
88
+ check_inputs_type(self.example_inputs)
89
+ if self.extra_inputs is not None:
90
+ check_inputs_type(self.extra_inputs)
91
+
92
+ for tag in self.tags:
93
+ _validate_tag(tag)
94
+
95
+ if not isinstance(self.description, str) or len(self.description) == 0:
96
+ raise ValueError(f'Invalid description: "{self.description}"')
97
+
98
+
99
+ _EXAMPLE_CASES: Dict[str, ExportCase] = {}
100
+ _MODULES: Set[ModuleType] = set()
101
+ _EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
102
+ _EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
103
+
104
+
105
+ def register_db_case(case: ExportCase) -> None:
106
+ """
107
+ Registers a user provided ExportCase into example bank.
108
+ """
109
+ if case.name in _EXAMPLE_CASES:
110
+ if case.name not in _EXAMPLE_CONFLICT_CASES:
111
+ _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
112
+ _EXAMPLE_CONFLICT_CASES[case.name].append(case)
113
+ return
114
+
115
+ _EXAMPLE_CASES[case.name] = case
116
+
117
+
118
+ def to_snake_case(name):
119
+ name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
120
+ return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
121
+
122
+
123
+ def _make_export_case(m, name, configs):
124
+ if not issubclass(m, torch.nn.Module):
125
+ raise TypeError("Export case class should be a torch.nn.Module.")
126
+ m = m()
127
+
128
+ if "description" not in configs:
129
+ # Fallback to docstring if description is missing.
130
+ assert (
131
+ m.__doc__ is not None
132
+ ), f"Could not find description or docstring for export case: {m}"
133
+ configs = {**configs, "description": m.__doc__}
134
+ return ExportCase(**{**configs, "model": m, "name": name})
135
+
136
+
137
+ def export_case(**kwargs):
138
+ """
139
+ Decorator for registering a user provided case into example bank.
140
+ """
141
+
142
+ def wrapper(m):
143
+ configs = kwargs
144
+ module = inspect.getmodule(m)
145
+ if module in _MODULES:
146
+ raise RuntimeError("export_case should only be used once per example file.")
147
+
148
+ assert module is not None
149
+ _MODULES.add(module)
150
+ normalized_name = to_snake_case(m.__name__)
151
+ module_name = module.__name__.split(".")[-1]
152
+ if module_name != normalized_name:
153
+ raise RuntimeError(
154
+ f'Module name "{module.__name__}" is inconsistent with exported program '
155
+ + f'name "{m.__name__}". Please rename the module to "{normalized_name}".'
156
+ )
157
+
158
+ case = _make_export_case(m, module_name, configs)
159
+ register_db_case(case)
160
+ return case
161
+
162
+ return wrapper
163
+
164
+
165
+ def export_rewrite_case(**kwargs):
166
+ def wrapper(m):
167
+ configs = kwargs
168
+
169
+ parent = configs.pop("parent")
170
+ assert isinstance(parent, ExportCase)
171
+ key = parent.name
172
+ if key not in _EXAMPLE_REWRITE_CASES:
173
+ _EXAMPLE_REWRITE_CASES[key] = []
174
+
175
+ configs["example_inputs"] = parent.example_inputs
176
+ case = _make_export_case(m, to_snake_case(m.__name__), configs)
177
+ _EXAMPLE_REWRITE_CASES[key].append(case)
178
+ return case
179
+
180
+ return wrapper
181
+
182
+
183
+ def normalize_inputs(x: InputsType) -> ExportArgs:
184
+ if isinstance(x, tuple):
185
+ return ExportArgs(*x)
186
+
187
+ assert isinstance(x, ExportArgs)
188
+ return x
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc ADDED
Binary file (1.85 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc ADDED
Binary file (3.03 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc ADDED
Binary file (3.23 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc ADDED
Binary file (1.67 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc ADDED
Binary file (2.19 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc ADDED
Binary file (1.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc ADDED
Binary file (1.41 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc ADDED
Binary file (1.69 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc ADDED
Binary file (1.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc ADDED
Binary file (1.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc ADDED
Binary file (1.63 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc ADDED
Binary file (2.04 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc ADDED
Binary file (2.05 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc ADDED
Binary file (1.21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc ADDED
Binary file (1.62 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc ADDED
Binary file (1.62 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc ADDED
Binary file (1.81 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc ADDED
Binary file (1.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ class MyAutogradFunction(torch.autograd.Function):
7
+ @staticmethod
8
+ def forward(ctx, x):
9
+ return x.clone()
10
+
11
+ @staticmethod
12
+ def backward(ctx, grad_output):
13
+ return grad_output + 1
14
+
15
+
16
+ @export_case(
17
+ example_inputs=(torch.randn(3, 2),),
18
+ )
19
+ class AutogradFunction(torch.nn.Module):
20
+ """
21
+ TorchDynamo does not keep track of backward() on autograd functions. We recommend to
22
+ use `allow_in_graph` to mitigate this problem.
23
+ """
24
+
25
+ def forward(self, x):
26
+ return MyAutogradFunction.apply(x)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import cond
5
+
6
+
7
+ class MySubModule(torch.nn.Module):
8
+ def foo(self, x):
9
+ return x.cos()
10
+
11
+ def forward(self, x):
12
+ return self.foo(x)
13
+
14
+
15
+ @export_case(
16
+ example_inputs=(torch.ones(3),),
17
+ tags={
18
+ "torch.cond",
19
+ "torch.dynamic-shape",
20
+ },
21
+ )
22
+ class CondBranchClassMethod(torch.nn.Module):
23
+ """
24
+ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
25
+ - both branches must take the same args, which must also match the branch args passed to cond.
26
+ - both branches must return a single tensor
27
+ - returned tensor must have the same tensor metadata, e.g. shape and dtype
28
+ - branch function can be free function, nested function, lambda, class methods
29
+ - branch function can not have closure variables
30
+ - no inplace mutations on inputs or global variables
31
+
32
+
33
+ This example demonstrates using class method in cond().
34
+
35
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
36
+ """
37
+
38
+ def __init__(self):
39
+ super().__init__()
40
+ self.subm = MySubModule()
41
+
42
+ def bar(self, x):
43
+ return x.sin()
44
+
45
+ def forward(self, x):
46
+ return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import cond
5
+
6
+
7
+ @export_case(
8
+ example_inputs=(torch.ones(6),),
9
+ tags={
10
+ "torch.cond",
11
+ "torch.dynamic-shape",
12
+ },
13
+ )
14
+ class CondBranchNonlocalVariables(torch.nn.Module):
15
+ """
16
+ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
17
+ - both branches must take the same args, which must also match the branch args passed to cond.
18
+ - both branches must return a single tensor
19
+ - returned tensor must have the same tensor metadata, e.g. shape and dtype
20
+ - branch function can be free function, nested function, lambda, class methods
21
+ - branch function can not have closure variables
22
+ - no inplace mutations on inputs or global variables
23
+
24
+ This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
25
+
26
+ The code below will not work because capturing closure variables is not supported.
27
+ ```
28
+ my_tensor_var = x + 100
29
+ my_primitive_var = 3.14
30
+
31
+ def true_fn(y):
32
+ nonlocal my_tensor_var, my_primitive_var
33
+ return y + my_tensor_var + my_primitive_var
34
+
35
+ def false_fn(y):
36
+ nonlocal my_tensor_var, my_primitive_var
37
+ return y - my_tensor_var - my_primitive_var
38
+
39
+ return cond(x.shape[0] > 5, true_fn, false_fn, [x])
40
+ ```
41
+
42
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
43
+ """
44
+
45
+ def __init__(self):
46
+ super().__init__()
47
+
48
+ def forward(self, x):
49
+ my_tensor_var = x + 100
50
+ my_primitive_var = 3.14
51
+
52
+ def true_fn(x, y, z):
53
+ return x + y + z
54
+
55
+ def false_fn(x, y, z):
56
+ return x - y - z
57
+
58
+ return cond(
59
+ x.shape[0] > 5,
60
+ true_fn,
61
+ false_fn,
62
+ [x, my_tensor_var, torch.tensor(my_primitive_var)],
63
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+
5
+ from torch._export.db.case import export_case
6
+
7
+
8
+ def test_decorator(func):
9
+ @functools.wraps(func)
10
+ def wrapper(*args, **kwargs):
11
+ return func(*args, **kwargs) + 1
12
+
13
+ return wrapper
14
+
15
+
16
+ @export_case(
17
+ example_inputs=(torch.ones(3, 2), torch.ones(3, 2)),
18
+ )
19
+ class Decorator(torch.nn.Module):
20
+ """
21
+ Decorators calls are inlined into the exported function during tracing.
22
+ """
23
+
24
+ @test_decorator
25
+ def forward(self, x, y):
26
+ return x + y
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"python.assert"},
9
+ )
10
+ class DynamicShapeAssert(torch.nn.Module):
11
+ """
12
+ A basic usage of python assertion.
13
+ """
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x):
18
+ # assertion with error message
19
+ assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2"
20
+ # assertion without error message
21
+ assert x.shape[0] > 1
22
+ return x
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"torch.dynamic-shape"},
9
+ )
10
+ class DynamicShapeConstructor(torch.nn.Module):
11
+ """
12
+ Tensor constructors should be captured with dynamic shape inputs rather
13
+ than being baked in with static shape.
14
+ """
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, x):
19
+ return torch.ones(x.shape[0] * 2)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+ from functorch.experimental.control_flow import map
5
+
6
+
7
+ @export_case(
8
+ example_inputs=(torch.ones(3, 2), torch.ones(2)),
9
+ tags={"torch.dynamic-shape", "torch.map"},
10
+ )
11
+ class DynamicShapeMap(torch.nn.Module):
12
+ """
13
+ functorch map() maps a function over the first tensor dimension.
14
+ """
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, xs, y):
20
+ def body(x, y):
21
+ return x + y
22
+
23
+ return map(body, xs, y)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+ from torch.export import Dim
5
+
6
+ x = torch.ones(3, 2)
7
+ dim0_x = Dim("dim0_x")
8
+
9
+ @export_case(
10
+ example_inputs=(x,),
11
+ tags={"torch.dynamic-shape", "python.builtin"},
12
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
13
+ dynamic_shapes={"x": {0: dim0_x}},
14
+ )
15
+ class DynamicShapeRound(torch.nn.Module):
16
+ """
17
+ Calling round on dynamic shapes is not supported.
18
+ """
19
+
20
+ def __init__(self):
21
+ super().__init__()
22
+
23
+ def forward(self, x):
24
+ return x[: round(x.shape[0] / 2)]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, ExportArgs, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=ExportArgs(
8
+ torch.randn(4),
9
+ (torch.randn(4), torch.randn(4)),
10
+ *[torch.randn(4), torch.randn(4)],
11
+ mykw0=torch.randn(4),
12
+ input0=torch.randn(4), input1=torch.randn(4)
13
+ ),
14
+ tags={"python.data-structure"},
15
+ support_level=SupportLevel.SUPPORTED,
16
+ )
17
+ class FnWithKwargs(torch.nn.Module):
18
+ """
19
+ Keyword arguments are not supported at the moment.
20
+ """
21
+ def __init__(self):
22
+ super().__init__()
23
+
24
+ def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
25
+ out = pos0
26
+ for arg in tuple0:
27
+ out = out * arg
28
+ for arg in myargs:
29
+ out = out * arg
30
+ out = out * mykw0
31
+ out = out * mykwargs["input0"] * mykwargs["input1"]
32
+ return out
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2), torch.ones(2)),
8
+ tags={"python.closure"},
9
+ )
10
+ class NestedFunction(torch.nn.Module):
11
+ """
12
+ Nested functions are traced through. Side effects on global captures
13
+ are not supported though.
14
+ """
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, a, b):
19
+ x = a + b
20
+ z = a - b
21
+
22
+ def closure(y):
23
+ nonlocal x
24
+ x += 1
25
+ return x * y + z
26
+
27
+ return closure(x)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+ import torch
4
+
5
+ from torch._export.db.case import export_case
6
+
7
+
8
+ @export_case(
9
+ example_inputs=(torch.ones(3, 2),),
10
+ tags={"python.context-manager"},
11
+ )
12
+ class NullContextManager(torch.nn.Module):
13
+ """
14
+ Null context manager in Python will be traced out.
15
+ """
16
+
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ def forward(self, x):
21
+ """
22
+ Null context manager in Python will be traced out.
23
+ """
24
+ ctx = contextlib.nullcontext()
25
+ with ctx:
26
+ return x.sin() + x.cos()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.randn(2, 3),),
8
+ tags={"python.object-model"},
9
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
10
+ )
11
+ class OptionalInput(torch.nn.Module):
12
+ """
13
+ Tracing through optional input is not supported yet
14
+ """
15
+
16
+ def forward(self, x, y=torch.ones(2, 3)):
17
+ if y is not None:
18
+ return x + y
19
+ return x
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"torch.operator"},
9
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
10
+ )
11
+ class TorchSymMin(torch.nn.Module):
12
+ """
13
+ torch.sym_min operator is not supported in export.
14
+ """
15
+
16
+ def forward(self, x):
17
+ return x.sum() + torch.sym_min(x.size(0), 100)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch._export.db.case import export_case, SupportLevel
4
+
5
+
6
+ @export_case(
7
+ example_inputs=(torch.ones(3, 2),),
8
+ tags={"torch.mutation"},
9
+ support_level=SupportLevel.SUPPORTED,
10
+ )
11
+ class UserInputMutation(torch.nn.Module):
12
+ """
13
+ Directly mutate user input in forward
14
+ """
15
+
16
+ def forward(self, x):
17
+ x.mul_(2)
18
+ return x.cos()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ def exportdb_error_message(case_name: str):
2
+ return ""
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (225 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc ADDED
Binary file (2.12 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Set
2
+
3
+
4
+ NodeMetadataValue = Any
5
+
6
+
7
+ PROTECTED_KEYS: Set[str] = {
8
+ "val",
9
+ "stack_trace",
10
+ "nn_module_stack",
11
+ "debug_handle",
12
+ "tensor_meta",
13
+ }
14
+
15
+
16
+ class NodeMetadata:
17
+ def __init__(self, data: Dict[str, Any]) -> None:
18
+ self.data: Dict[str, Any] = data.copy()
19
+
20
+ def __getitem__(self, key: str) -> NodeMetadataValue:
21
+ return self.data[key]
22
+
23
+ def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
24
+ if key in PROTECTED_KEYS:
25
+ raise RuntimeError(f"Could not override node key: {key}")
26
+ self.data[key] = value
27
+
28
+ def __contains__(self, key: str) -> bool:
29
+ return key in self.data
30
+
31
+ def copy(self) -> "NodeMetadata":
32
+ return NodeMetadata(self.data.copy())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc ADDED
Binary file (4.39 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc ADDED
Binary file (1.45 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc ADDED
Binary file (4.31 kB). View file