koichi12 commited on
Commit
f681997
·
verified ·
1 Parent(s): 5dbc224

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 +3 -0
  3. .venv/lib/python3.11/site-packages/torch/_export/__init__.py +317 -0
  4. .venv/lib/python3.11/site-packages/torch/_export/db/__init__.py +5 -0
  5. .venv/lib/python3.11/site-packages/torch/_export/db/case.py +174 -0
  6. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py +61 -0
  7. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py +20 -0
  34. .venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py +23 -0
  35. .venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py +22 -0
  36. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py +44 -0
  37. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py +41 -0
  38. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py +59 -0
  39. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py +22 -0
  40. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py +36 -0
  41. .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py +25 -0
  42. .venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py +25 -0
  43. .venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py +28 -0
  44. .venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py +23 -0
  45. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py +17 -0
  46. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py +18 -0
  47. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py +15 -0
  48. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py +19 -0
  49. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py +19 -0
  50. .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py +21 -0
.gitattributes CHANGED
@@ -125,3 +125,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
125
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
126
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
125
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
126
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
128
+ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:242b9dba953ae2e4878d66032624135a9118a1616ca24588ed586d4bcc475c69
3
+ size 108421928
.venv/lib/python3.11/site-packages/torch/_export/__init__.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import copy
3
+ import dataclasses
4
+ import functools
5
+ import io
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ import sys
11
+ import types
12
+ import warnings
13
+ import weakref
14
+ import zipfile
15
+ from collections import OrderedDict
16
+ from contextlib import contextmanager
17
+ from functools import lru_cache
18
+
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+ from unittest.mock import patch
21
+
22
+ import torch
23
+ import torch.fx
24
+ import torch.utils._pytree as pytree
25
+
26
+ from torch._dispatch.python import enable_python_dispatcher
27
+ from torch._utils_internal import log_export_usage
28
+ from torch.export._tree_utils import reorder_kwargs
29
+ from torch.export.graph_signature import (
30
+ ArgumentSpec,
31
+ ConstantArgument,
32
+ ExportGraphSignature,
33
+ InputKind,
34
+ InputSpec,
35
+ OutputKind,
36
+ OutputSpec,
37
+ SymIntArgument,
38
+ TensorArgument,
39
+ )
40
+ from torch.fx import traceback as fx_traceback
41
+ from torch.fx._compatibility import compatibility
42
+ from torch.fx.experimental.proxy_tensor import make_fx
43
+ from torch._subclasses.fake_tensor import unset_fake_temporarily
44
+ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
45
+
46
+ from .wrappers import _wrap_submodules
47
+
48
+ log = logging.getLogger(__name__)
49
+
50
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.

    Instances are converted with ``dataclasses.asdict`` and applied via
    ``torch._dynamo.config.patch`` while tracing (see
    ``capture_pre_autograd_graph`` below).
    """
    # Export traces through RNN modules instead of rejecting them.
    allow_rnn: bool = True
+
57
+
58
# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
# is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
    """Log a one-time deprecation banner for capture_pre_autograd_graph()."""
    from torch._inductor import config

    # Same string twice frames the banner; lru_cache on the no-arg call
    # guarantees the whole body runs at most once per process.
    border = "+============================+"
    log.warning(border)
    log.warning("| !!! WARNING !!! |")
    log.warning(border)
    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
    log.warning("Please switch to use torch.export.export_for_training instead.")
    if config.is_fbcode():
        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
71
+
72
+
73
@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
    """
    A helper function that is intended to trace a module before any pre-autograd
    decomposition is run. The produced module will be "non-functional" and
    composed of aten operators. Later this API will be deleted in favor of more general
    torch.export API.

    Args:
        f: nn.Module to be traced

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

    Returns:
        An nn.Module containing the traced method.

    """
    # Imports are deferred to call time to avoid import cycles between
    # torch._export and torch.export.
    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch.export._unlift import _create_stateful_graph_module
    from torch.export.dynamic_shapes import _combine_args

    # One-time deprecation banner (memoized with lru_cache).
    capture_pre_autograd_graph_warning()

    if sys.platform == "win32":
        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}

    if capture_pre_autograd_graph_using_training_ir():
        # New path: delegate entirely to the training-IR exporter.
        @lru_cache
        def print_export_warning():
            log.warning("Using torch.export.export_for_training(...,strict=True)")
        print_export_warning()
        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
    else:
        # Legacy path: trace with dynamo directly, pre-dispatch, to an aten graph.
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        decomp_table = {
            op: op.decompose
            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }
        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
            # torch._dynamo.export returns (graph_module, guards, ...); only the
            # traced graph module is needed here, hence the trailing [0].
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

        _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

        # Keep only ranges for unbacked symbols (named like "i0"/"f0") as
        # inline constraints on the graph metadata.
        m.meta["inline_constraints"] = {
            k: v
            for k, v in fake_mode.shape_env.var_to_range.items()
            if re.match(r"^[if]\d+$", str(k))
        }

        if isinstance(f, torch.nn.Module):
            from torch.export._trace import _restore_state_dict
            _restore_state_dict(f, m)

        flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
        combined_args = _combine_args(f, args, kwargs)
        range_constraints = make_constraints(
            fake_mode,
            m,
            combined_args,
            dynamic_shapes,
            0,
        )

        module = _create_stateful_graph_module(
            m,
            range_constraints=range_constraints,
        )

    # Block train()/eval() on the exported module: mode switching is not
    # meaningful on a frozen graph. Users may monkey-patch replacements as
    # described in the message.
    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]

    # Remove Proxy because they cannot be deepcopied or pickled.
    if hasattr(module, "_buffers"):
        torch._export.utils.remove_proxy_from_state_dict(
            module._buffers, in_place=True
        )
    return module
217
+
218
+
219
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> str:
    """
    Note: this function is not stable yet

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable cpp code from the program, and returns
    the path to the generated shared library

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        options: A dictionary of options to control inductor

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

    Returns:
        Path to the generated shared library
    """
    # NOTE(review): `remove_runtime_assertions` is accepted but never read in
    # this body, and `select_decomp_table` is imported but unused — both look
    # like leftovers; confirm before removing (callers may pass the kwarg).
    from torch.export._trace import _export_to_torch_ir
    from torch._inductor.decomposition import select_decomp_table
    from torch._inductor import config

    if config.is_predispatch:
        gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
            same_signature=same_signature,
            # Disabling this flag, because instead we can rely on the mapping
            # dynamo_flat_name_to_original_fqn which is coming from Dynamo.
            restore_fqn=False,
        )

    # Compilation is a pure trace-and-codegen step; no autograd needed.
    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    return so_path
290
+
291
def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable

    Args:
        so_path: Path to the shared library

    Returns:
        A callable
    """
    # Only CPU and CUDA (optionally with an explicit index, e.g. "cuda:1")
    # are supported; anything else is rejected up front.
    if device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    elif device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):
        # The compiled container records serialized pytree specs for its
        # inputs/outputs; use them to flatten the caller's arguments and to
        # rebuild the structured result from the flat output tensors.
        specs = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(specs[0])
        out_spec = pytree.treespec_loads(specs[1])
        ordered_kwargs = reorder_kwargs(kwargs, in_spec)
        all_leaves = pytree.tree_flatten((args, ordered_kwargs))[0]
        # Non-tensor leaves (e.g. constants baked in at compile time) are dropped.
        tensor_inputs = [leaf for leaf in all_leaves if isinstance(leaf, torch.Tensor)]
        flat_outputs = runner.run(tensor_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
.venv/lib/python3.11/site-packages/torch/_export/db/case.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ import re
4
+ import string
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any, Dict, List, Optional, Set, Tuple
8
+ from types import ModuleType
9
+
10
+ import torch
11
+
12
# Registry of valid example tags, organized as a shallow tree keyed by
# namespace ("torch" / "python"); a tag string is a dot-joined path into this
# tree, e.g. "torch.cond" or "python.closure" (see _validate_tag).
_TAGS: Dict[str, Dict[str, Any]] = {
    "torch": {
        "cond": {},
        "dynamic-shape": {},
        "escape-hatch": {},
        "map": {},
        "dynamic-value": {},
        "operator": {},
        "mutation": {},
    },
    "python": {
        "assert": {},
        "builtin": {},
        "closure": {},
        "context-manager": {},
        "control-flow": {},
        "data-structure": {},
        "standard-library": {},
        "object-model": {},
    },
}
33
+
34
+
35
class SupportLevel(Enum):
    """
    Indicates at what stage the feature
    used in the example is handled in export.
    """

    SUPPORTED = 1  # export handles this feature today
    NOT_SUPPORTED_YET = 0  # known gap; example documents the failure mode
43
+
44
+
45
# Alias for an example's positional-argument tuple (see ExportCase.example_args).
ArgsType = Tuple[Any, ...]
46
+
47
+
48
def check_inputs_type(args, kwargs):
    """Validate example inputs: ``args`` must be a tuple, ``kwargs`` a dict
    whose keys are all strings. Raises ValueError otherwise."""
    if not isinstance(args, tuple):
        raise ValueError(f"Expecting args type to be a tuple, got: {type(args)}")
    if not isinstance(kwargs, dict):
        raise ValueError(f"Expecting kwargs type to be a dict, got: {type(kwargs)}")
    for key in kwargs:
        if isinstance(key, str):
            continue
        raise ValueError(f"Expecting kwargs keys to be a string, got: {type(key)}")
62
+
63
def _validate_tag(tag: str):
    """Check that a dotted tag path (e.g. "torch.cond") exists in _TAGS."""
    allowed = set(string.ascii_lowercase + "-")
    node = _TAGS
    for part in tag.split("."):
        # Character-set check first (assert, as in the original), then walk
        # one level deeper in the registered-tag tree.
        assert set(part) <= allowed, f"Tag contains invalid characters: {part}"
        if part not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        node = node[part]
74
+
75
+
76
@dataclass(frozen=True)
class ExportCase:
    """A single exportdb example: a model plus example inputs and metadata.

    Validation happens eagerly in ``__post_init__`` so a malformed example
    fails at registration time, not when it is first exercised.
    """

    example_args: ArgsType
    description: str  # A description of the use case.
    model: torch.nn.Module
    name: str
    example_kwargs: Dict[str, Any] = field(default_factory=dict)
    extra_args: Optional[ArgsType] = None  # For testing graph generalization.
    # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
    tags: Set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # args/kwargs must be a (tuple, dict[str, Any]) pair.
        check_inputs_type(self.example_args, self.example_kwargs)
        if self.extra_args is not None:
            check_inputs_type(self.extra_args, {})

        # Every tag must exist in the _TAGS registry.
        for tag in self.tags:
            _validate_tag(tag)

        if not isinstance(self.description, str) or len(self.description) == 0:
            raise ValueError(f'Invalid description: "{self.description}"')
99
+
100
+
101
# Global registries populated by the decorators below:
_EXAMPLE_CASES: Dict[str, ExportCase] = {}  # case name -> registered example
_MODULES: Set[ModuleType] = set()  # modules that already used @export_case (one per file)
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}  # name -> all cases that clashed on it
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}  # parent case name -> rewrite variants
105
+
106
+
107
def register_db_case(case: ExportCase) -> None:
    """
    Registers a user provided ExportCase into example bank.
    """
    name = case.name
    if name not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[name] = case
        return

    # Duplicate name: record the clash (seeding with the first-registered
    # case) so the package __init__ can raise one aggregated error later.
    conflicts = _EXAMPLE_CONFLICT_CASES.setdefault(name, [_EXAMPLE_CASES[name]])
    conflicts.append(case)
118
+
119
+
120
def to_snake_case(name):
    """Convert a CamelCase identifier to snake_case (acronym-aware)."""
    # First break "Xyz" words off whatever precedes them, then split any
    # remaining lower/digit -> upper boundaries, and lowercase the result.
    broken = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    broken = re.sub("([a-z0-9])([A-Z])", r"\1_\2", broken)
    return broken.lower()
123
+
124
+
125
def _make_export_case(m, name, configs):
    """Build an ExportCase named ``name`` from module instance ``m`` and
    decorator-supplied ``configs``; falls back to the docstring for the
    description."""
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")

    if "description" not in configs:
        # Fallback to docstring if description is missing.
        doc = m.__doc__
        assert doc is not None, f"Could not find description or docstring for export case: {m}"
        configs = {**configs, "description": doc}
    # Dict-merge (not keyword args) so explicit model/name win over configs.
    merged = {**configs, "model": m, "name": name}
    return ExportCase(**merged)
136
+
137
+
138
def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.
    """

    def wrapper(m):
        # Enforce one export_case per example file.
        module = inspect.getmodule(m)
        if module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")

        assert module is not None
        _MODULES.add(module)
        # The case is named after the example file (last component of the
        # module path).
        case = _make_export_case(m, module.__name__.split(".")[-1], kwargs)
        register_db_case(case)
        return case

    return wrapper
157
+
158
+
159
def export_rewrite_case(**kwargs):
    """Decorator registering a rewrite variant of an existing case; requires
    a ``parent`` ExportCase in the keyword args and inherits its example_args."""

    def wrapper(m):
        configs = kwargs
        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)

        # Rewrites reuse the parent's example inputs and are grouped under
        # the parent's name.
        configs["example_args"] = parent.example_args
        case = _make_export_case(m, to_snake_case(m.__name__), configs)
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import dataclasses
3
+ import glob
4
+ import inspect
5
+ from os.path import basename, dirname, isfile, join
6
+
7
+ import torch
8
+ from torch._export.db.case import (
9
+ _EXAMPLE_CASES,
10
+ _EXAMPLE_CONFLICT_CASES,
11
+ _EXAMPLE_REWRITE_CASES,
12
+ SupportLevel,
13
+ export_case,
14
+ ExportCase,
15
+ )
16
+
17
+
18
def _collect_examples():
    """Import every sibling example module and register it as an ExportCase.

    Each example file defines module-level variables named after ExportCase
    fields (e.g. ``model``, ``example_args``, ``tags``); those are harvested
    and forwarded to the ``export_case`` decorator applied to ``case.model``.
    """
    case_names = glob.glob(join(dirname(__file__), "*.py"))
    case_names = [
        basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py")
    ]

    case_fields = {f.name for f in dataclasses.fields(ExportCase)}
    for case_name in case_names:
        # level=1 -> import relative to this package.
        case = __import__(case_name, globals(), locals(), [], 1)
        variables = [name for name in dir(case) if name in case_fields]
        export_case(**{v: getattr(case, v) for v in variables})(case.model)

# Populate the registry eagerly at package-import time.
_collect_examples()
31
+
32
def all_examples():
    """Return the registry mapping case name -> ExportCase."""
    return _EXAMPLE_CASES
34
+
35
+
36
# If two example modules registered the same case name, fail loudly at
# package-import time with a message listing every clashing model class.
if len(_EXAMPLE_CONFLICT_CASES) > 0:

    def get_name(case):
        # Report the class name whether model is an instance or a class.
        model = case.model
        if isinstance(model, torch.nn.Module):
            model = type(model)
        return model.__name__

    msg = "Error on conflict export case name.\n"
    for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
        msg += f"Case name {case_name} is associated with multiple cases:\n "
        msg += f"[{','.join(map(get_name, cases))}]\n"

    raise RuntimeError(msg)
50
+
51
+
52
def filter_examples_by_support_level(support_level: SupportLevel):
    """Return the registered examples whose support_level matches exactly."""
    filtered = {}
    for name, example in all_examples().items():
        if example.support_level == support_level:
            filtered[name] = example
    return filtered
58
+
59
+
60
def get_rewrite_cases(case):
    """Return rewrite variants registered for ``case`` (empty list if none)."""
    return _EXAMPLE_REWRITE_CASES.get(case.name, [])
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc ADDED
Binary file (1.72 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc ADDED
Binary file (1.79 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc ADDED
Binary file (2.96 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc ADDED
Binary file (2.79 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc ADDED
Binary file (1.54 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc ADDED
Binary file (1.76 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc ADDED
Binary file (1.43 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc ADDED
Binary file (1.42 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc ADDED
Binary file (1.11 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc ADDED
Binary file (1.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc ADDED
Binary file (1.29 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc ADDED
Binary file (1.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc ADDED
Binary file (1.52 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc ADDED
Binary file (1.22 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc ADDED
Binary file (1.36 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc ADDED
Binary file (1.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc ADDED
Binary file (1.38 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc ADDED
Binary file (1.12 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc ADDED
Binary file (1.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc ADDED
Binary file (1.04 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc ADDED
Binary file (1.21 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc ADDED
Binary file (1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch._dynamo as torchdynamo
4
+
5
+
6
class AssumeConstantResult(torch.nn.Module):
    """
    Applying `assume_constant_result` decorator to mark non-traceable code as constant.
    """

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # .item() is data-dependent and would otherwise break tracing; the
        # decorator tells dynamo to bake the result in as a constant.
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]

# exportdb metadata harvested by examples/__init__._collect_examples.
example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class MyAutogradFunction(torch.autograd.Function):
5
+ @staticmethod
6
+ def forward(ctx, x):
7
+ return x.clone()
8
+
9
+ @staticmethod
10
+ def backward(ctx, grad_output):
11
+ return grad_output + 1
12
+
13
+ class AutogradFunction(torch.nn.Module):
14
+ """
15
+ TorchDynamo does not keep track of backward() on autograd functions. We recommend to
16
+ use `allow_in_graph` to mitigate this problem.
17
+ """
18
+
19
+ def forward(self, x):
20
+ return MyAutogradFunction.apply(x)
21
+
22
+ example_args = (torch.randn(3, 2),)
23
+ model = AutogradFunction()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class ClassMethod(torch.nn.Module):
5
+ """
6
+ Class methods are inlined during tracing.
7
+ """
8
+
9
+ @classmethod
10
+ def method(cls, x):
11
+ return x + 1
12
+
13
+ def __init__(self) -> None:
14
+ super().__init__()
15
+ self.linear = torch.nn.Linear(4, 2)
16
+
17
+ def forward(self, x):
18
+ x = self.linear(x)
19
+ return self.method(x) * self.__class__.method(x) * type(self).method(x)
20
+
21
+ example_args = (torch.randn(3, 4),)
22
+ model = ClassMethod()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import cond
5
+
6
+ class MySubModule(torch.nn.Module):
7
+ def foo(self, x):
8
+ return x.cos()
9
+
10
+ def forward(self, x):
11
+ return self.foo(x)
12
+
13
+ class CondBranchClassMethod(torch.nn.Module):
14
+ """
15
+ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
16
+ - both branches must take the same args, which must also match the branch args passed to cond.
17
+ - both branches must return a single tensor
18
+ - returned tensor must have the same tensor metadata, e.g. shape and dtype
19
+ - branch function can be free function, nested function, lambda, class methods
20
+ - branch function can not have closure variables
21
+ - no inplace mutations on inputs or global variables
22
+
23
+
24
+ This example demonstrates using class method in cond().
25
+
26
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
27
+ """
28
+
29
+ def __init__(self) -> None:
30
+ super().__init__()
31
+ self.subm = MySubModule()
32
+
33
+ def bar(self, x):
34
+ return x.sin()
35
+
36
+ def forward(self, x):
37
+ return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
38
+
39
+ example_args = (torch.randn(3),)
40
+ tags = {
41
+ "torch.cond",
42
+ "torch.dynamic-shape",
43
+ }
44
+ model = CondBranchClassMethod()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import cond
5
+
6
+ class CondBranchNestedFunction(torch.nn.Module):
7
+ """
8
+ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
9
+ - both branches must take the same args, which must also match the branch args passed to cond.
10
+ - both branches must return a single tensor
11
+ - returned tensor must have the same tensor metadata, e.g. shape and dtype
12
+ - branch function can be free function, nested function, lambda, class methods
13
+ - branch function can not have closure variables
14
+ - no inplace mutations on inputs or global variables
15
+
16
+ This example demonstrates using nested function in cond().
17
+
18
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
19
+ """
20
+
21
+ def forward(self, x):
22
+ def true_fn(x):
23
+ def inner_true_fn(y):
24
+ return x + y
25
+
26
+ return inner_true_fn(x)
27
+
28
+ def false_fn(x):
29
+ def inner_false_fn(y):
30
+ return x - y
31
+
32
+ return inner_false_fn(x)
33
+
34
+ return cond(x.shape[0] < 10, true_fn, false_fn, [x])
35
+
36
+ example_args = (torch.randn(3),)
37
+ tags = {
38
+ "torch.cond",
39
+ "torch.dynamic-shape",
40
+ }
41
+ model = CondBranchNestedFunction()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import cond
5
+
6
+ class CondBranchNonlocalVariables(torch.nn.Module):
7
+ """
8
+ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
9
+ - both branches must take the same args, which must also match the branch args passed to cond.
10
+ - both branches must return a single tensor
11
+ - returned tensor must have the same tensor metadata, e.g. shape and dtype
12
+ - branch function can be free function, nested function, lambda, class methods
13
+ - branch function can not have closure variables
14
+ - no inplace mutations on inputs or global variables
15
+
16
+ This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
17
+
18
+ The code below will not work because capturing closure variables is not supported.
19
+ ```
20
+ my_tensor_var = x + 100
21
+ my_primitive_var = 3.14
22
+
23
+ def true_fn(y):
24
+ nonlocal my_tensor_var, my_primitive_var
25
+ return y + my_tensor_var + my_primitive_var
26
+
27
+ def false_fn(y):
28
+ nonlocal my_tensor_var, my_primitive_var
29
+ return y - my_tensor_var - my_primitive_var
30
+
31
+ return cond(x.shape[0] > 5, true_fn, false_fn, [x])
32
+ ```
33
+
34
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
35
+ """
36
+
37
+ def forward(self, x):
38
+ my_tensor_var = x + 100
39
+ my_primitive_var = 3.14
40
+
41
+ def true_fn(x, y, z):
42
+ return x + y + z
43
+
44
+ def false_fn(x, y, z):
45
+ return x - y - z
46
+
47
+ return cond(
48
+ x.shape[0] > 5,
49
+ true_fn,
50
+ false_fn,
51
+ [x, my_tensor_var, torch.tensor(my_primitive_var)],
52
+ )
53
+
54
+ example_args = (torch.randn(6),)
55
+ tags = {
56
+ "torch.cond",
57
+ "torch.dynamic-shape",
58
+ }
59
+ model = CondBranchNonlocalVariables()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import cond
5
+
6
+ class CondClosedOverVariable(torch.nn.Module):
7
+ """
8
+ torch.cond() supports branches closed over arbitrary variables.
9
+ """
10
+
11
+ def forward(self, pred, x):
12
+ def true_fn(val):
13
+ return x * 2
14
+
15
+ def false_fn(val):
16
+ return x - 2
17
+
18
+ return cond(pred, true_fn, false_fn, [x + 1])
19
+
20
+ example_args = (torch.tensor(True), torch.randn(3, 2))
21
+ tags = {"torch.cond", "python.closure"}
22
+ model = CondClosedOverVariable()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from torch.export import Dim
5
+ from functorch.experimental.control_flow import cond
6
+
7
+ x = torch.randn(3, 2)
8
+ y = torch.randn(2)
9
+ dim0_x = Dim("dim0_x")
10
+
11
+ class CondOperands(torch.nn.Module):
12
+ """
13
+ The operands passed to cond() must be:
14
+ - a list of tensors
15
+ - match arguments of `true_fn` and `false_fn`
16
+
17
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
18
+ """
19
+
20
+ def forward(self, x, y):
21
+ def true_fn(x, y):
22
+ return x + y
23
+
24
+ def false_fn(x, y):
25
+ return x - y
26
+
27
+ return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
28
+
29
+ example_args = (x, y)
30
+ tags = {
31
+ "torch.cond",
32
+ "torch.dynamic-shape",
33
+ }
34
+ extra_inputs = (torch.randn(2, 2), torch.randn(2))
35
+ dynamic_shapes = {"x": {0: dim0_x}, "y": None}
36
+ model = CondOperands()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import cond
5
+
6
+ class CondPredicate(torch.nn.Module):
7
+ """
8
+ The conditional statement (aka predicate) passed to cond() must be one of the following:
9
+ - torch.Tensor with a single element
10
+ - boolean expression
11
+
12
+ NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
13
+ """
14
+
15
+ def forward(self, x):
16
+ pred = x.dim() > 2 and x.shape[2] > 10
17
+
18
+ return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
19
+
20
+ example_args = (torch.randn(6, 4, 3),)
21
+ tags = {
22
+ "torch.cond",
23
+ "torch.dynamic-shape",
24
+ }
25
+ model = CondPredicate()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+
5
+ class ConstrainAsSizeExample(torch.nn.Module):
6
+ """
7
+ If the value is not known at tracing time, you can provide hint so that we
8
+ can trace further. Please look at torch._check and torch._check_is_size APIs.
9
+ torch._check_is_size is used for values that NEED to be used for constructing
10
+ tensor.
11
+ """
12
+
13
+ def forward(self, x):
14
+ a = x.item()
15
+ torch._check_is_size(a)
16
+ torch._check(a <= 5)
17
+ return torch.zeros((a, 5))
18
+
19
+
20
+ example_args = (torch.tensor(4),)
21
+ tags = {
22
+ "torch.dynamic-value",
23
+ "torch.escape-hatch",
24
+ }
25
+ model = ConstrainAsSizeExample()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+
5
+ class ConstrainAsValueExample(torch.nn.Module):
6
+ """
7
+ If the value is not known at tracing time, you can provide hint so that we
8
+ can trace further. Please look at torch._check and torch._check_is_size APIs.
9
+ torch._check is used for values that don't need to be used for constructing
10
+ tensor.
11
+ """
12
+
13
+ def forward(self, x, y):
14
+ a = x.item()
15
+ torch._check(a >= 0)
16
+ torch._check(a <= 5)
17
+
18
+ if a < 6:
19
+ return y.sin()
20
+ return y.cos()
21
+
22
+
23
+ example_args = (torch.tensor(4), torch.randn(5, 5))
24
+ tags = {
25
+ "torch.dynamic-value",
26
+ "torch.escape-hatch",
27
+ }
28
+ model = ConstrainAsValueExample()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+
4
+ import torch
5
+
6
+ def test_decorator(func):
7
+ @functools.wraps(func)
8
+ def wrapper(*args, **kwargs):
9
+ return func(*args, **kwargs) + 1
10
+
11
+ return wrapper
12
+
13
+ class Decorator(torch.nn.Module):
14
+ """
15
+ Decorators calls are inlined into the exported function during tracing.
16
+ """
17
+
18
+ @test_decorator
19
+ def forward(self, x, y):
20
+ return x + y
21
+
22
+ example_args = (torch.randn(3, 2), torch.randn(3, 2))
23
+ model = Decorator()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class Dictionary(torch.nn.Module):
5
+ """
6
+ Dictionary structures are inlined and flattened along tracing.
7
+ """
8
+
9
+ def forward(self, x, y):
10
+ elements = {}
11
+ elements["x2"] = x * x
12
+ y = y * elements["x2"]
13
+ return {"y": y}
14
+
15
+ example_args = (torch.randn(3, 2), torch.tensor(4))
16
+ tags = {"python.data-structure"}
17
+ model = Dictionary()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class DynamicShapeAssert(torch.nn.Module):
5
+ """
6
+ A basic usage of python assertion.
7
+ """
8
+
9
+ def forward(self, x):
10
+ # assertion with error message
11
+ assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2"
12
+ # assertion without error message
13
+ assert x.shape[0] > 1
14
+ return x
15
+
16
+ example_args = (torch.randn(3, 2),)
17
+ tags = {"python.assert"}
18
+ model = DynamicShapeAssert()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class DynamicShapeConstructor(torch.nn.Module):
5
+ """
6
+ Tensor constructors should be captured with dynamic shape inputs rather
7
+ than being baked in with static shape.
8
+ """
9
+
10
+ def forward(self, x):
11
+ return torch.zeros(x.shape[0] * 2)
12
+
13
+ example_args = (torch.randn(3, 2),)
14
+ tags = {"torch.dynamic-shape"}
15
+ model = DynamicShapeConstructor()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ class DynamicShapeIfGuard(torch.nn.Module):
5
+ """
6
+ `if` statement with backed dynamic shape predicate will be specialized into
7
+ one particular branch and generate a guard. However, export will fail if the
8
+ the dimension is marked as dynamic shape from higher level API.
9
+ """
10
+
11
+ def forward(self, x):
12
+ if x.shape[0] == 3:
13
+ return x.cos()
14
+
15
+ return x.sin()
16
+
17
+ example_args = (torch.randn(3, 2, 2),)
18
+ tags = {"torch.dynamic-shape", "python.control-flow"}
19
+ model = DynamicShapeIfGuard()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from functorch.experimental.control_flow import map
5
+
6
+ class DynamicShapeMap(torch.nn.Module):
7
+ """
8
+ functorch map() maps a function over the first tensor dimension.
9
+ """
10
+
11
+ def forward(self, xs, y):
12
+ def body(x, y):
13
+ return x + y
14
+
15
+ return map(body, xs, y)
16
+
17
+ example_args = (torch.randn(3, 2), torch.randn(2))
18
+ tags = {"torch.dynamic-shape", "torch.map"}
19
+ model = DynamicShapeMap()
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+ from torch._export.db.case import SupportLevel
5
+ from torch.export import Dim
6
+
7
+ class DynamicShapeRound(torch.nn.Module):
8
+ """
9
+ Calling round on dynamic shapes is not supported.
10
+ """
11
+
12
+ def forward(self, x):
13
+ return x[: round(x.shape[0] / 2)]
14
+
15
+ x = torch.randn(3, 2)
16
+ dim0_x = Dim("dim0_x")
17
+ example_args = (x,)
18
+ tags = {"torch.dynamic-shape", "python.builtin"}
19
+ support_level = SupportLevel.NOT_SUPPORTED_YET
20
+ dynamic_shapes = {"x": {0: dim0_x}}
21
+ model = DynamicShapeRound()