Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 +3 -0
- .venv/lib/python3.11/site-packages/torch/_export/__init__.py +317 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/__init__.py +5 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/case.py +174 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py +61 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py +20 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py +23 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py +22 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py +44 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py +41 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py +59 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py +22 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py +36 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py +25 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py +25 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py +28 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py +23 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py +17 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py +18 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py +15 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py +19 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py +19 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py +21 -0
.gitattributes
CHANGED
|
@@ -125,3 +125,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 125 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
|
| 126 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 125 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
|
| 126 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 128 |
+
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:242b9dba953ae2e4878d66032624135a9118a1616ca24588ed586d4bcc475c69
|
| 3 |
+
size 108421928
|
.venv/lib/python3.11/site-packages/torch/_export/__init__.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import copy
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import io
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
import types
|
| 12 |
+
import warnings
|
| 13 |
+
import weakref
|
| 14 |
+
import zipfile
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from contextlib import contextmanager
|
| 17 |
+
from functools import lru_cache
|
| 18 |
+
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
from unittest.mock import patch
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.fx
|
| 24 |
+
import torch.utils._pytree as pytree
|
| 25 |
+
|
| 26 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 27 |
+
from torch._utils_internal import log_export_usage
|
| 28 |
+
from torch.export._tree_utils import reorder_kwargs
|
| 29 |
+
from torch.export.graph_signature import (
|
| 30 |
+
ArgumentSpec,
|
| 31 |
+
ConstantArgument,
|
| 32 |
+
ExportGraphSignature,
|
| 33 |
+
InputKind,
|
| 34 |
+
InputSpec,
|
| 35 |
+
OutputKind,
|
| 36 |
+
OutputSpec,
|
| 37 |
+
SymIntArgument,
|
| 38 |
+
TensorArgument,
|
| 39 |
+
)
|
| 40 |
+
from torch.fx import traceback as fx_traceback
|
| 41 |
+
from torch.fx._compatibility import compatibility
|
| 42 |
+
from torch.fx.experimental.proxy_tensor import make_fx
|
| 43 |
+
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
| 44 |
+
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
| 45 |
+
|
| 46 |
+
from .wrappers import _wrap_submodules
|
| 47 |
+
|
| 48 |
+
log = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
@dataclasses.dataclass
|
| 51 |
+
class ExportDynamoConfig:
|
| 52 |
+
"""
|
| 53 |
+
Manage Export-specific configurations of Dynamo.
|
| 54 |
+
"""
|
| 55 |
+
allow_rnn: bool = True
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
|
| 59 |
+
# is called multiple times.
|
| 60 |
+
@lru_cache
|
| 61 |
+
def capture_pre_autograd_graph_warning():
|
| 62 |
+
from torch._inductor import config
|
| 63 |
+
|
| 64 |
+
log.warning("+============================+")
|
| 65 |
+
log.warning("| !!! WARNING !!! |")
|
| 66 |
+
log.warning("+============================+")
|
| 67 |
+
log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
|
| 68 |
+
log.warning("Please switch to use torch.export.export_for_training instead.")
|
| 69 |
+
if config.is_fbcode():
|
| 70 |
+
log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@compatibility(is_backward_compatible=False)
|
| 74 |
+
def capture_pre_autograd_graph(
|
| 75 |
+
f: torch.nn.Module,
|
| 76 |
+
args: Tuple[Any],
|
| 77 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 78 |
+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 79 |
+
) -> torch.nn.Module:
|
| 80 |
+
"""
|
| 81 |
+
A helper function that is intended to trace a module before any pre-autograd
|
| 82 |
+
decomposition is run. The produced module will be "non-functional" and
|
| 83 |
+
composed of aten operators. Later this API will be deleted in favor of more general
|
| 84 |
+
torch.export API.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
f: nn.Module to be traced
|
| 88 |
+
|
| 89 |
+
args: example positional inputs.
|
| 90 |
+
|
| 91 |
+
kwargs: optional example keyword inputs.
|
| 92 |
+
|
| 93 |
+
dynamic_shapes: Should either be:
|
| 94 |
+
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
| 95 |
+
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
| 96 |
+
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
| 97 |
+
is defined in the original function signature.
|
| 98 |
+
|
| 99 |
+
The dynamic shape of a tensor argument can be specified as either
|
| 100 |
+
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
| 101 |
+
not required to include static dimension indices in this dict, but when they are,
|
| 102 |
+
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
| 103 |
+
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
| 104 |
+
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
| 105 |
+
recursively specified by using mappings or sequences of contained specifications.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
An nn.Module containing the traced method.
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
|
| 112 |
+
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
|
| 113 |
+
from torch._export.non_strict_utils import make_constraints
|
| 114 |
+
from torch._subclasses.functional_tensor import FunctionalTensor
|
| 115 |
+
from torch.export._unlift import _create_stateful_graph_module
|
| 116 |
+
from torch.export.dynamic_shapes import _combine_args
|
| 117 |
+
|
| 118 |
+
capture_pre_autograd_graph_warning()
|
| 119 |
+
|
| 120 |
+
if sys.platform == "win32":
|
| 121 |
+
raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
|
| 122 |
+
|
| 123 |
+
assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
|
| 124 |
+
|
| 125 |
+
if kwargs is None:
|
| 126 |
+
kwargs = {}
|
| 127 |
+
|
| 128 |
+
if capture_pre_autograd_graph_using_training_ir():
|
| 129 |
+
@lru_cache
|
| 130 |
+
def print_export_warning():
|
| 131 |
+
log.warning("Using torch.export.export_for_training(...,strict=True)")
|
| 132 |
+
print_export_warning()
|
| 133 |
+
module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
|
| 134 |
+
else:
|
| 135 |
+
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
|
| 136 |
+
|
| 137 |
+
# Do not decompose dropout for exported models, because in eval mode the dropout
|
| 138 |
+
# op disappears from the graph, which makes it difficult to switch to train mode.
|
| 139 |
+
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
|
| 140 |
+
decomp_table = {
|
| 141 |
+
op: op.decompose
|
| 142 |
+
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
|
| 143 |
+
if op != torch.ops.aten.dropout.default
|
| 144 |
+
}
|
| 145 |
+
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
|
| 146 |
+
m = torch._dynamo.export(
|
| 147 |
+
f,
|
| 148 |
+
dynamic_shapes=dynamic_shapes,
|
| 149 |
+
assume_static_by_default=True,
|
| 150 |
+
tracing_mode="symbolic",
|
| 151 |
+
decomposition_table=decomp_table,
|
| 152 |
+
pre_dispatch=True,
|
| 153 |
+
aten_graph=True,
|
| 154 |
+
_log_export_usage=False,
|
| 155 |
+
)(
|
| 156 |
+
*args,
|
| 157 |
+
**kwargs,
|
| 158 |
+
)[0]
|
| 159 |
+
|
| 160 |
+
_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
|
| 161 |
+
|
| 162 |
+
m.meta["inline_constraints"] = {
|
| 163 |
+
k: v
|
| 164 |
+
for k, v in fake_mode.shape_env.var_to_range.items()
|
| 165 |
+
if re.match(r"^[if]\d+$", str(k))
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
if isinstance(f, torch.nn.Module):
|
| 169 |
+
from torch.export._trace import _restore_state_dict
|
| 170 |
+
_restore_state_dict(f, m)
|
| 171 |
+
|
| 172 |
+
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
|
| 173 |
+
combined_args = _combine_args(f, args, kwargs)
|
| 174 |
+
range_constraints = make_constraints(
|
| 175 |
+
fake_mode,
|
| 176 |
+
m,
|
| 177 |
+
combined_args,
|
| 178 |
+
dynamic_shapes,
|
| 179 |
+
0,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
module = _create_stateful_graph_module(
|
| 183 |
+
m,
|
| 184 |
+
range_constraints=range_constraints,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
error_message = \
|
| 188 |
+
"""
|
| 189 |
+
Calling train() or eval() is not supported for exported models.
|
| 190 |
+
Alternatively, you may override these methods to do custom user behavior as follows:
|
| 191 |
+
|
| 192 |
+
def _my_train(self, mode: bool = True):
|
| 193 |
+
...
|
| 194 |
+
|
| 195 |
+
def _my_eval(self):
|
| 196 |
+
...
|
| 197 |
+
|
| 198 |
+
model.train = types.MethodType(_my_train, model)
|
| 199 |
+
model.eval = types.MethodType(_my_eval, model)
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def _train(self, mode: bool = True):
|
| 203 |
+
raise NotImplementedError(error_message)
|
| 204 |
+
|
| 205 |
+
def _eval(self, mode: bool = True):
|
| 206 |
+
raise NotImplementedError(error_message)
|
| 207 |
+
|
| 208 |
+
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
|
| 209 |
+
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
|
| 210 |
+
|
| 211 |
+
# Remove Proxy because they cannot be deepcopied or pickled.
|
| 212 |
+
if hasattr(module, "_buffers"):
|
| 213 |
+
torch._export.utils.remove_proxy_from_state_dict(
|
| 214 |
+
module._buffers, in_place=True
|
| 215 |
+
)
|
| 216 |
+
return module
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def aot_compile(
|
| 220 |
+
f: Callable,
|
| 221 |
+
args: Tuple[Any],
|
| 222 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 223 |
+
*,
|
| 224 |
+
dynamic_shapes: Optional[Dict[str, Any]] = None,
|
| 225 |
+
options: Optional[Dict[str, Any]] = None,
|
| 226 |
+
remove_runtime_assertions: bool = False,
|
| 227 |
+
disable_constraint_solver: bool = False,
|
| 228 |
+
same_signature: bool = True,
|
| 229 |
+
) -> str:
|
| 230 |
+
"""
|
| 231 |
+
Note: this function is not stable yet
|
| 232 |
+
|
| 233 |
+
Traces either an nn.Module's forward function or just a callable with PyTorch
|
| 234 |
+
operations inside, generates executable cpp code from the program, and returns
|
| 235 |
+
the path to the generated shared library
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
f: the `nn.Module` or callable to trace.
|
| 239 |
+
|
| 240 |
+
args: example positional inputs.
|
| 241 |
+
|
| 242 |
+
kwargs: optional example keyword inputs.
|
| 243 |
+
|
| 244 |
+
dynamic_shapes: Should either be:
|
| 245 |
+
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
| 246 |
+
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
| 247 |
+
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
| 248 |
+
is defined in the original function signature.
|
| 249 |
+
|
| 250 |
+
The dynamic shape of a tensor argument can be specified as either
|
| 251 |
+
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
| 252 |
+
not required to include static dimension indices in this dict, but when they are,
|
| 253 |
+
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
| 254 |
+
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
| 255 |
+
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
| 256 |
+
recursively specified by using mappings or sequences of contained specifications.
|
| 257 |
+
|
| 258 |
+
options: A dictionary of options to control inductor
|
| 259 |
+
|
| 260 |
+
disable_constraint_solver: Whether the dim constraint solver must be disabled.
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
Path to the generated shared library
|
| 264 |
+
"""
|
| 265 |
+
from torch.export._trace import _export_to_torch_ir
|
| 266 |
+
from torch._inductor.decomposition import select_decomp_table
|
| 267 |
+
from torch._inductor import config
|
| 268 |
+
|
| 269 |
+
if config.is_predispatch:
|
| 270 |
+
gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
|
| 271 |
+
else:
|
| 272 |
+
# We want to export to Torch IR here to utilize the pre_grad passes in
|
| 273 |
+
# inductor, which run on Torch IR.
|
| 274 |
+
gm = _export_to_torch_ir(
|
| 275 |
+
f,
|
| 276 |
+
args,
|
| 277 |
+
kwargs,
|
| 278 |
+
dynamic_shapes,
|
| 279 |
+
disable_constraint_solver=disable_constraint_solver,
|
| 280 |
+
same_signature=same_signature,
|
| 281 |
+
# Disabling this flag, because instead we can rely on the mapping
|
| 282 |
+
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
|
| 283 |
+
restore_fqn=False,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
with torch.no_grad():
|
| 287 |
+
so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type]
|
| 288 |
+
|
| 289 |
+
return so_path
|
| 290 |
+
|
| 291 |
+
def aot_load(so_path: str, device: str) -> Callable:
|
| 292 |
+
"""
|
| 293 |
+
Loads a shared library generated by aot_compile and returns a callable
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
so_path: Path to the shared library
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
A callable
|
| 300 |
+
"""
|
| 301 |
+
if device == "cpu":
|
| 302 |
+
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
|
| 303 |
+
elif device == "cuda" or device.startswith("cuda:"):
|
| 304 |
+
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
|
| 305 |
+
else:
|
| 306 |
+
raise RuntimeError("Unsupported device " + device)
|
| 307 |
+
|
| 308 |
+
def optimized(*args, **kwargs):
|
| 309 |
+
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
|
| 310 |
+
in_spec = pytree.treespec_loads(call_spec[0])
|
| 311 |
+
out_spec = pytree.treespec_loads(call_spec[1])
|
| 312 |
+
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
|
| 313 |
+
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
|
| 314 |
+
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
|
| 315 |
+
return pytree.tree_unflatten(flat_outputs, out_spec)
|
| 316 |
+
|
| 317 |
+
return optimized
|
.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
.venv/lib/python3.11/site-packages/torch/_export/db/case.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
import re
|
| 4 |
+
import string
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 8 |
+
from types import ModuleType
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
_TAGS: Dict[str, Dict[str, Any]] = {
|
| 13 |
+
"torch": {
|
| 14 |
+
"cond": {},
|
| 15 |
+
"dynamic-shape": {},
|
| 16 |
+
"escape-hatch": {},
|
| 17 |
+
"map": {},
|
| 18 |
+
"dynamic-value": {},
|
| 19 |
+
"operator": {},
|
| 20 |
+
"mutation": {},
|
| 21 |
+
},
|
| 22 |
+
"python": {
|
| 23 |
+
"assert": {},
|
| 24 |
+
"builtin": {},
|
| 25 |
+
"closure": {},
|
| 26 |
+
"context-manager": {},
|
| 27 |
+
"control-flow": {},
|
| 28 |
+
"data-structure": {},
|
| 29 |
+
"standard-library": {},
|
| 30 |
+
"object-model": {},
|
| 31 |
+
},
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SupportLevel(Enum):
|
| 36 |
+
"""
|
| 37 |
+
Indicates at what stage the feature
|
| 38 |
+
used in the example is handled in export.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
SUPPORTED = 1
|
| 42 |
+
NOT_SUPPORTED_YET = 0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
ArgsType = Tuple[Any, ...]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def check_inputs_type(args, kwargs):
|
| 49 |
+
if not isinstance(args, tuple):
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"Expecting args type to be a tuple, got: {type(args)}"
|
| 52 |
+
)
|
| 53 |
+
if not isinstance(kwargs, dict):
|
| 54 |
+
raise ValueError(
|
| 55 |
+
f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
|
| 56 |
+
)
|
| 57 |
+
for key in kwargs:
|
| 58 |
+
if not isinstance(key, str):
|
| 59 |
+
raise ValueError(
|
| 60 |
+
f"Expecting kwargs keys to be a string, got: {type(key)}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def _validate_tag(tag: str):
|
| 64 |
+
parts = tag.split(".")
|
| 65 |
+
t = _TAGS
|
| 66 |
+
for part in parts:
|
| 67 |
+
assert set(part) <= set(
|
| 68 |
+
string.ascii_lowercase + "-"
|
| 69 |
+
), f"Tag contains invalid characters: {part}"
|
| 70 |
+
if part in t:
|
| 71 |
+
t = t[part]
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Tag {tag} is not found in registered tags.")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass(frozen=True)
|
| 77 |
+
class ExportCase:
|
| 78 |
+
example_args: ArgsType
|
| 79 |
+
description: str # A description of the use case.
|
| 80 |
+
model: torch.nn.Module
|
| 81 |
+
name: str
|
| 82 |
+
example_kwargs: Dict[str, Any] = field(default_factory=dict)
|
| 83 |
+
extra_args: Optional[ArgsType] = None # For testing graph generalization.
|
| 84 |
+
# Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
|
| 85 |
+
tags: Set[str] = field(default_factory=set)
|
| 86 |
+
support_level: SupportLevel = SupportLevel.SUPPORTED
|
| 87 |
+
dynamic_shapes: Optional[Dict[str, Any]] = None
|
| 88 |
+
|
| 89 |
+
def __post_init__(self):
|
| 90 |
+
check_inputs_type(self.example_args, self.example_kwargs)
|
| 91 |
+
if self.extra_args is not None:
|
| 92 |
+
check_inputs_type(self.extra_args, {})
|
| 93 |
+
|
| 94 |
+
for tag in self.tags:
|
| 95 |
+
_validate_tag(tag)
|
| 96 |
+
|
| 97 |
+
if not isinstance(self.description, str) or len(self.description) == 0:
|
| 98 |
+
raise ValueError(f'Invalid description: "{self.description}"')
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
|
| 102 |
+
_MODULES: Set[ModuleType] = set()
|
| 103 |
+
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
|
| 104 |
+
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_db_case(case: ExportCase) -> None:
|
| 108 |
+
"""
|
| 109 |
+
Registers a user provided ExportCase into example bank.
|
| 110 |
+
"""
|
| 111 |
+
if case.name in _EXAMPLE_CASES:
|
| 112 |
+
if case.name not in _EXAMPLE_CONFLICT_CASES:
|
| 113 |
+
_EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
|
| 114 |
+
_EXAMPLE_CONFLICT_CASES[case.name].append(case)
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
_EXAMPLE_CASES[case.name] = case
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def to_snake_case(name):
|
| 121 |
+
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
| 122 |
+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _make_export_case(m, name, configs):
|
| 126 |
+
if not isinstance(m, torch.nn.Module):
|
| 127 |
+
raise TypeError("Export case class should be a torch.nn.Module.")
|
| 128 |
+
|
| 129 |
+
if "description" not in configs:
|
| 130 |
+
# Fallback to docstring if description is missing.
|
| 131 |
+
assert (
|
| 132 |
+
m.__doc__ is not None
|
| 133 |
+
), f"Could not find description or docstring for export case: {m}"
|
| 134 |
+
configs = {**configs, "description": m.__doc__}
|
| 135 |
+
return ExportCase(**{**configs, "model": m, "name": name})
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def export_case(**kwargs):
|
| 139 |
+
"""
|
| 140 |
+
Decorator for registering a user provided case into example bank.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
def wrapper(m):
|
| 144 |
+
configs = kwargs
|
| 145 |
+
module = inspect.getmodule(m)
|
| 146 |
+
if module in _MODULES:
|
| 147 |
+
raise RuntimeError("export_case should only be used once per example file.")
|
| 148 |
+
|
| 149 |
+
assert module is not None
|
| 150 |
+
_MODULES.add(module)
|
| 151 |
+
module_name = module.__name__.split(".")[-1]
|
| 152 |
+
case = _make_export_case(m, module_name, configs)
|
| 153 |
+
register_db_case(case)
|
| 154 |
+
return case
|
| 155 |
+
|
| 156 |
+
return wrapper
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def export_rewrite_case(**kwargs):
|
| 160 |
+
def wrapper(m):
|
| 161 |
+
configs = kwargs
|
| 162 |
+
|
| 163 |
+
parent = configs.pop("parent")
|
| 164 |
+
assert isinstance(parent, ExportCase)
|
| 165 |
+
key = parent.name
|
| 166 |
+
if key not in _EXAMPLE_REWRITE_CASES:
|
| 167 |
+
_EXAMPLE_REWRITE_CASES[key] = []
|
| 168 |
+
|
| 169 |
+
configs["example_args"] = parent.example_args
|
| 170 |
+
case = _make_export_case(m, to_snake_case(m.__name__), configs)
|
| 171 |
+
_EXAMPLE_REWRITE_CASES[key].append(case)
|
| 172 |
+
return case
|
| 173 |
+
|
| 174 |
+
return wrapper
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import dataclasses
|
| 3 |
+
import glob
|
| 4 |
+
import inspect
|
| 5 |
+
from os.path import basename, dirname, isfile, join
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch._export.db.case import (
|
| 9 |
+
_EXAMPLE_CASES,
|
| 10 |
+
_EXAMPLE_CONFLICT_CASES,
|
| 11 |
+
_EXAMPLE_REWRITE_CASES,
|
| 12 |
+
SupportLevel,
|
| 13 |
+
export_case,
|
| 14 |
+
ExportCase,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _collect_examples():
|
| 19 |
+
case_names = glob.glob(join(dirname(__file__), "*.py"))
|
| 20 |
+
case_names = [
|
| 21 |
+
basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py")
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
case_fields = {f.name for f in dataclasses.fields(ExportCase)}
|
| 25 |
+
for case_name in case_names:
|
| 26 |
+
case = __import__(case_name, globals(), locals(), [], 1)
|
| 27 |
+
variables = [name for name in dir(case) if name in case_fields]
|
| 28 |
+
export_case(**{v: getattr(case, v) for v in variables})(case.model)
|
| 29 |
+
|
| 30 |
+
_collect_examples()
|
| 31 |
+
|
| 32 |
+
def all_examples():
|
| 33 |
+
return _EXAMPLE_CASES
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if len(_EXAMPLE_CONFLICT_CASES) > 0:
|
| 37 |
+
|
| 38 |
+
def get_name(case):
|
| 39 |
+
model = case.model
|
| 40 |
+
if isinstance(model, torch.nn.Module):
|
| 41 |
+
model = type(model)
|
| 42 |
+
return model.__name__
|
| 43 |
+
|
| 44 |
+
msg = "Error on conflict export case name.\n"
|
| 45 |
+
for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
|
| 46 |
+
msg += f"Case name {case_name} is associated with multiple cases:\n "
|
| 47 |
+
msg += f"[{','.join(map(get_name, cases))}]\n"
|
| 48 |
+
|
| 49 |
+
raise RuntimeError(msg)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def filter_examples_by_support_level(support_level: SupportLevel):
|
| 53 |
+
return {
|
| 54 |
+
key: val
|
| 55 |
+
for key, val in all_examples().items()
|
| 56 |
+
if val.support_level == support_level
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_rewrite_cases(case):
|
| 61 |
+
return _EXAMPLE_REWRITE_CASES.get(case.name, [])
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc
ADDED
|
Binary file (2.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc
ADDED
|
Binary file (1.76 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc
ADDED
|
Binary file (1.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc
ADDED
|
Binary file (1.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc
ADDED
|
Binary file (1.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc
ADDED
|
Binary file (1.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc
ADDED
|
Binary file (1.38 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc
ADDED
|
Binary file (1.12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc
ADDED
|
Binary file (1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
import torch._dynamo as torchdynamo
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AssumeConstantResult(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Applying `assume_constant_result` decorator to burn make non-tracable code as constant.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
@torchdynamo.assume_constant_result
|
| 12 |
+
def get_item(self, y):
|
| 13 |
+
return y.int().item()
|
| 14 |
+
|
| 15 |
+
def forward(self, x, y):
|
| 16 |
+
return x[: self.get_item(y)]
|
| 17 |
+
|
| 18 |
+
example_args = (torch.randn(3, 2), torch.tensor(4))
|
| 19 |
+
tags = {"torch.escape-hatch"}
|
| 20 |
+
model = AssumeConstantResult()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class MyAutogradFunction(torch.autograd.Function):
|
| 5 |
+
@staticmethod
|
| 6 |
+
def forward(ctx, x):
|
| 7 |
+
return x.clone()
|
| 8 |
+
|
| 9 |
+
@staticmethod
|
| 10 |
+
def backward(ctx, grad_output):
|
| 11 |
+
return grad_output + 1
|
| 12 |
+
|
| 13 |
+
class AutogradFunction(torch.nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
TorchDynamo does not keep track of backward() on autograd functions. We recommend to
|
| 16 |
+
use `allow_in_graph` to mitigate this problem.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
return MyAutogradFunction.apply(x)
|
| 21 |
+
|
| 22 |
+
example_args = (torch.randn(3, 2),)
|
| 23 |
+
model = AutogradFunction()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class ClassMethod(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
Class methods are inlined during tracing.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
@classmethod
|
| 10 |
+
def method(cls, x):
|
| 11 |
+
return x + 1
|
| 12 |
+
|
| 13 |
+
def __init__(self) -> None:
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.linear = torch.nn.Linear(4, 2)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
x = self.linear(x)
|
| 19 |
+
return self.method(x) * self.__class__.method(x) * type(self).method(x)
|
| 20 |
+
|
| 21 |
+
example_args = (torch.randn(3, 4),)
|
| 22 |
+
model = ClassMethod()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
class MySubModule(torch.nn.Module):
|
| 7 |
+
def foo(self, x):
|
| 8 |
+
return x.cos()
|
| 9 |
+
|
| 10 |
+
def forward(self, x):
|
| 11 |
+
return self.foo(x)
|
| 12 |
+
|
| 13 |
+
class CondBranchClassMethod(torch.nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 16 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 17 |
+
- both branches must return a single tensor
|
| 18 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 19 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 20 |
+
- branch function can not have closure variables
|
| 21 |
+
- no inplace mutations on inputs or global variables
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
This example demonstrates using class method in cond().
|
| 25 |
+
|
| 26 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.subm = MySubModule()
|
| 32 |
+
|
| 33 |
+
def bar(self, x):
|
| 34 |
+
return x.sin()
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
|
| 38 |
+
|
| 39 |
+
example_args = (torch.randn(3),)
|
| 40 |
+
tags = {
|
| 41 |
+
"torch.cond",
|
| 42 |
+
"torch.dynamic-shape",
|
| 43 |
+
}
|
| 44 |
+
model = CondBranchClassMethod()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
class CondBranchNestedFunction(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 9 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 10 |
+
- both branches must return a single tensor
|
| 11 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 12 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 13 |
+
- branch function can not have closure variables
|
| 14 |
+
- no inplace mutations on inputs or global variables
|
| 15 |
+
|
| 16 |
+
This example demonstrates using nested function in cond().
|
| 17 |
+
|
| 18 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
def true_fn(x):
|
| 23 |
+
def inner_true_fn(y):
|
| 24 |
+
return x + y
|
| 25 |
+
|
| 26 |
+
return inner_true_fn(x)
|
| 27 |
+
|
| 28 |
+
def false_fn(x):
|
| 29 |
+
def inner_false_fn(y):
|
| 30 |
+
return x - y
|
| 31 |
+
|
| 32 |
+
return inner_false_fn(x)
|
| 33 |
+
|
| 34 |
+
return cond(x.shape[0] < 10, true_fn, false_fn, [x])
|
| 35 |
+
|
| 36 |
+
example_args = (torch.randn(3),)
|
| 37 |
+
tags = {
|
| 38 |
+
"torch.cond",
|
| 39 |
+
"torch.dynamic-shape",
|
| 40 |
+
}
|
| 41 |
+
model = CondBranchNestedFunction()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
class CondBranchNonlocalVariables(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 9 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 10 |
+
- both branches must return a single tensor
|
| 11 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 12 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 13 |
+
- branch function can not have closure variables
|
| 14 |
+
- no inplace mutations on inputs or global variables
|
| 15 |
+
|
| 16 |
+
This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
|
| 17 |
+
|
| 18 |
+
The code below will not work because capturing closure variables is not supported.
|
| 19 |
+
```
|
| 20 |
+
my_tensor_var = x + 100
|
| 21 |
+
my_primitive_var = 3.14
|
| 22 |
+
|
| 23 |
+
def true_fn(y):
|
| 24 |
+
nonlocal my_tensor_var, my_primitive_var
|
| 25 |
+
return y + my_tensor_var + my_primitive_var
|
| 26 |
+
|
| 27 |
+
def false_fn(y):
|
| 28 |
+
nonlocal my_tensor_var, my_primitive_var
|
| 29 |
+
return y - my_tensor_var - my_primitive_var
|
| 30 |
+
|
| 31 |
+
return cond(x.shape[0] > 5, true_fn, false_fn, [x])
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
my_tensor_var = x + 100
|
| 39 |
+
my_primitive_var = 3.14
|
| 40 |
+
|
| 41 |
+
def true_fn(x, y, z):
|
| 42 |
+
return x + y + z
|
| 43 |
+
|
| 44 |
+
def false_fn(x, y, z):
|
| 45 |
+
return x - y - z
|
| 46 |
+
|
| 47 |
+
return cond(
|
| 48 |
+
x.shape[0] > 5,
|
| 49 |
+
true_fn,
|
| 50 |
+
false_fn,
|
| 51 |
+
[x, my_tensor_var, torch.tensor(my_primitive_var)],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
example_args = (torch.randn(6),)
|
| 55 |
+
tags = {
|
| 56 |
+
"torch.cond",
|
| 57 |
+
"torch.dynamic-shape",
|
| 58 |
+
}
|
| 59 |
+
model = CondBranchNonlocalVariables()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
class CondClosedOverVariable(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
torch.cond() supports branches closed over arbitrary variables.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def forward(self, pred, x):
|
| 12 |
+
def true_fn(val):
|
| 13 |
+
return x * 2
|
| 14 |
+
|
| 15 |
+
def false_fn(val):
|
| 16 |
+
return x - 2
|
| 17 |
+
|
| 18 |
+
return cond(pred, true_fn, false_fn, [x + 1])
|
| 19 |
+
|
| 20 |
+
example_args = (torch.tensor(True), torch.randn(3, 2))
|
| 21 |
+
tags = {"torch.cond", "python.closure"}
|
| 22 |
+
model = CondClosedOverVariable()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from torch.export import Dim
|
| 5 |
+
from functorch.experimental.control_flow import cond
|
| 6 |
+
|
| 7 |
+
x = torch.randn(3, 2)
|
| 8 |
+
y = torch.randn(2)
|
| 9 |
+
dim0_x = Dim("dim0_x")
|
| 10 |
+
|
| 11 |
+
class CondOperands(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
The operands passed to cond() must be:
|
| 14 |
+
- a list of tensors
|
| 15 |
+
- match arguments of `true_fn` and `false_fn`
|
| 16 |
+
|
| 17 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def forward(self, x, y):
|
| 21 |
+
def true_fn(x, y):
|
| 22 |
+
return x + y
|
| 23 |
+
|
| 24 |
+
def false_fn(x, y):
|
| 25 |
+
return x - y
|
| 26 |
+
|
| 27 |
+
return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
|
| 28 |
+
|
| 29 |
+
example_args = (x, y)
|
| 30 |
+
tags = {
|
| 31 |
+
"torch.cond",
|
| 32 |
+
"torch.dynamic-shape",
|
| 33 |
+
}
|
| 34 |
+
extra_inputs = (torch.randn(2, 2), torch.randn(2))
|
| 35 |
+
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
|
| 36 |
+
model = CondOperands()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
class CondPredicate(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
The conditional statement (aka predicate) passed to cond() must be one of the following:
|
| 9 |
+
- torch.Tensor with a single element
|
| 10 |
+
- boolean expression
|
| 11 |
+
|
| 12 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
pred = x.dim() > 2 and x.shape[2] > 10
|
| 17 |
+
|
| 18 |
+
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
|
| 19 |
+
|
| 20 |
+
example_args = (torch.randn(6, 4, 3),)
|
| 21 |
+
tags = {
|
| 22 |
+
"torch.cond",
|
| 23 |
+
"torch.dynamic-shape",
|
| 24 |
+
}
|
| 25 |
+
model = CondPredicate()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ConstrainAsSizeExample(torch.nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
If the value is not known at tracing time, you can provide hint so that we
|
| 8 |
+
can trace further. Please look at torch._check and torch._check_is_size APIs.
|
| 9 |
+
torch._check_is_size is used for values that NEED to be used for constructing
|
| 10 |
+
tensor.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
a = x.item()
|
| 15 |
+
torch._check_is_size(a)
|
| 16 |
+
torch._check(a <= 5)
|
| 17 |
+
return torch.zeros((a, 5))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
example_args = (torch.tensor(4),)
|
| 21 |
+
tags = {
|
| 22 |
+
"torch.dynamic-value",
|
| 23 |
+
"torch.escape-hatch",
|
| 24 |
+
}
|
| 25 |
+
model = ConstrainAsSizeExample()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ConstrainAsValueExample(torch.nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
If the value is not known at tracing time, you can provide hint so that we
|
| 8 |
+
can trace further. Please look at torch._check and torch._check_is_size APIs.
|
| 9 |
+
torch._check is used for values that don't need to be used for constructing
|
| 10 |
+
tensor.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def forward(self, x, y):
|
| 14 |
+
a = x.item()
|
| 15 |
+
torch._check(a >= 0)
|
| 16 |
+
torch._check(a <= 5)
|
| 17 |
+
|
| 18 |
+
if a < 6:
|
| 19 |
+
return y.sin()
|
| 20 |
+
return y.cos()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
example_args = (torch.tensor(4), torch.randn(5, 5))
|
| 24 |
+
tags = {
|
| 25 |
+
"torch.dynamic-value",
|
| 26 |
+
"torch.escape-hatch",
|
| 27 |
+
}
|
| 28 |
+
model = ConstrainAsValueExample()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
def test_decorator(func):
|
| 7 |
+
@functools.wraps(func)
|
| 8 |
+
def wrapper(*args, **kwargs):
|
| 9 |
+
return func(*args, **kwargs) + 1
|
| 10 |
+
|
| 11 |
+
return wrapper
|
| 12 |
+
|
| 13 |
+
class Decorator(torch.nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Decorators calls are inlined into the exported function during tracing.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
@test_decorator
|
| 19 |
+
def forward(self, x, y):
|
| 20 |
+
return x + y
|
| 21 |
+
|
| 22 |
+
example_args = (torch.randn(3, 2), torch.randn(3, 2))
|
| 23 |
+
model = Decorator()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class Dictionary(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
Dictionary structures are inlined and flattened along tracing.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
def forward(self, x, y):
|
| 10 |
+
elements = {}
|
| 11 |
+
elements["x2"] = x * x
|
| 12 |
+
y = y * elements["x2"]
|
| 13 |
+
return {"y": y}
|
| 14 |
+
|
| 15 |
+
example_args = (torch.randn(3, 2), torch.tensor(4))
|
| 16 |
+
tags = {"python.data-structure"}
|
| 17 |
+
model = Dictionary()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class DynamicShapeAssert(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
A basic usage of python assertion.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
def forward(self, x):
|
| 10 |
+
# assertion with error message
|
| 11 |
+
assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2"
|
| 12 |
+
# assertion without error message
|
| 13 |
+
assert x.shape[0] > 1
|
| 14 |
+
return x
|
| 15 |
+
|
| 16 |
+
example_args = (torch.randn(3, 2),)
|
| 17 |
+
tags = {"python.assert"}
|
| 18 |
+
model = DynamicShapeAssert()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class DynamicShapeConstructor(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
Tensor constructors should be captured with dynamic shape inputs rather
|
| 7 |
+
than being baked in with static shape.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def forward(self, x):
|
| 11 |
+
return torch.zeros(x.shape[0] * 2)
|
| 12 |
+
|
| 13 |
+
example_args = (torch.randn(3, 2),)
|
| 14 |
+
tags = {"torch.dynamic-shape"}
|
| 15 |
+
model = DynamicShapeConstructor()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class DynamicShapeIfGuard(torch.nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
`if` statement with backed dynamic shape predicate will be specialized into
|
| 7 |
+
one particular branch and generate a guard. However, export will fail if the
|
| 8 |
+
the dimension is marked as dynamic shape from higher level API.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def forward(self, x):
|
| 12 |
+
if x.shape[0] == 3:
|
| 13 |
+
return x.cos()
|
| 14 |
+
|
| 15 |
+
return x.sin()
|
| 16 |
+
|
| 17 |
+
example_args = (torch.randn(3, 2, 2),)
|
| 18 |
+
tags = {"torch.dynamic-shape", "python.control-flow"}
|
| 19 |
+
model = DynamicShapeIfGuard()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from functorch.experimental.control_flow import map
|
| 5 |
+
|
| 6 |
+
class DynamicShapeMap(torch.nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
functorch map() maps a function over the first tensor dimension.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def forward(self, xs, y):
|
| 12 |
+
def body(x, y):
|
| 13 |
+
return x + y
|
| 14 |
+
|
| 15 |
+
return map(body, xs, y)
|
| 16 |
+
|
| 17 |
+
example_args = (torch.randn(3, 2), torch.randn(2))
|
| 18 |
+
tags = {"torch.dynamic-shape", "torch.map"}
|
| 19 |
+
model = DynamicShapeMap()
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from torch._export.db.case import SupportLevel
|
| 5 |
+
from torch.export import Dim
|
| 6 |
+
|
| 7 |
+
class DynamicShapeRound(torch.nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Calling round on dynamic shapes is not supported.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
return x[: round(x.shape[0] / 2)]
|
| 14 |
+
|
| 15 |
+
x = torch.randn(3, 2)
|
| 16 |
+
dim0_x = Dim("dim0_x")
|
| 17 |
+
example_args = (x,)
|
| 18 |
+
tags = {"torch.dynamic-shape", "python.builtin"}
|
| 19 |
+
support_level = SupportLevel.NOT_SUPPORTED_YET
|
| 20 |
+
dynamic_shapes = {"x": {0: dim0_x}}
|
| 21 |
+
model = DynamicShapeRound()
|