Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py +406 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py +44 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py +27 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py +27 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py +17 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py +41 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +141 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py +401 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py +114 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import dataclasses
|
| 3 |
+
import functools
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
+
import types
|
| 10 |
+
import warnings
|
| 11 |
+
import weakref
|
| 12 |
+
import zipfile
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from contextlib import contextmanager
|
| 15 |
+
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
from unittest.mock import patch
|
| 18 |
+
|
| 19 |
+
import sympy
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch._dynamo
|
| 23 |
+
import torch.fx
|
| 24 |
+
import torch.utils._pytree as pytree
|
| 25 |
+
|
| 26 |
+
from torch._decomp import core_aten_decompositions, get_decompositions
|
| 27 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 28 |
+
from torch._dynamo.exc import UserError, UserErrorType
|
| 29 |
+
from torch._dynamo.source import ConstantSource
|
| 30 |
+
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
|
| 31 |
+
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
|
| 32 |
+
from torch._functorch.eager_transforms import functionalize
|
| 33 |
+
from torch._guards import detect_fake_mode
|
| 34 |
+
from torch._inductor import config
|
| 35 |
+
from torch._ops import OpOverload
|
| 36 |
+
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
| 37 |
+
from torch._subclasses.functional_tensor import FunctionalTensor
|
| 38 |
+
from torch._utils_internal import log_export_usage
|
| 39 |
+
from torch.export._tree_utils import reorder_kwargs
|
| 40 |
+
from torch.export._unlift import _create_stateful_graph_module
|
| 41 |
+
from torch.export.dynamic_shapes import (
|
| 42 |
+
_process_constraints,
|
| 43 |
+
_process_dynamic_shapes,
|
| 44 |
+
Constraint,
|
| 45 |
+
dims,
|
| 46 |
+
dynamic_dim,
|
| 47 |
+
)
|
| 48 |
+
from torch.export.exported_program import (
|
| 49 |
+
_disable_prexisiting_fake_mode,
|
| 50 |
+
ExportedProgram,
|
| 51 |
+
ModuleCallEntry,
|
| 52 |
+
ModuleCallSignature,
|
| 53 |
+
)
|
| 54 |
+
from torch.export.graph_signature import (
|
| 55 |
+
_sig_to_specs,
|
| 56 |
+
ArgumentSpec,
|
| 57 |
+
ConstantArgument,
|
| 58 |
+
ExportGraphSignature,
|
| 59 |
+
InputKind,
|
| 60 |
+
InputSpec,
|
| 61 |
+
OutputKind,
|
| 62 |
+
OutputSpec,
|
| 63 |
+
SymIntArgument,
|
| 64 |
+
TensorArgument,
|
| 65 |
+
)
|
| 66 |
+
from torch.fx import traceback as fx_traceback
|
| 67 |
+
from torch.fx._compatibility import compatibility
|
| 68 |
+
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
|
| 69 |
+
from torch.fx.experimental.symbolic_shapes import (
|
| 70 |
+
ConstraintViolationError,
|
| 71 |
+
GuardOnDataDependentSymNode,
|
| 72 |
+
ShapeEnv,
|
| 73 |
+
StrictMinMaxConstraint,
|
| 74 |
+
)
|
| 75 |
+
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
| 76 |
+
from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges
|
| 77 |
+
|
| 78 |
+
from .passes.add_runtime_assertions_for_constraints_pass import (
|
| 79 |
+
_AddRuntimeAssertionsForInlineConstraintsPass,
|
| 80 |
+
)
|
| 81 |
+
from .wrappers import _wrap_submodules
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclasses.dataclass
|
| 85 |
+
class ExportDynamoConfig:
|
| 86 |
+
"""
|
| 87 |
+
Manage Export-specific configurations of Dynamo.
|
| 88 |
+
"""
|
| 89 |
+
allow_rnn: bool = True
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@compatibility(is_backward_compatible=False)
|
| 93 |
+
def capture_pre_autograd_graph(
|
| 94 |
+
f: torch.nn.Module,
|
| 95 |
+
args: Tuple[Any],
|
| 96 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 97 |
+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 98 |
+
) -> torch.nn.Module:
|
| 99 |
+
"""
|
| 100 |
+
A helper function that is intended to trace a module before any pre-autograd
|
| 101 |
+
decomposition is run. The produced module will be "non-functional" and
|
| 102 |
+
composed of aten operators. Later this API will be deleted in favor of more general
|
| 103 |
+
torch.export API.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
f: nn.Module to be traced
|
| 107 |
+
|
| 108 |
+
args: example positional inputs.
|
| 109 |
+
|
| 110 |
+
kwargs: optional example keyword inputs.
|
| 111 |
+
|
| 112 |
+
dynamic_shapes: Should either be:
|
| 113 |
+
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
| 114 |
+
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
| 115 |
+
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
| 116 |
+
is defined in the original function signature.
|
| 117 |
+
|
| 118 |
+
The dynamic shape of a tensor argument can be specified as either
|
| 119 |
+
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
| 120 |
+
not required to include static dimension indices in this dict, but when they are,
|
| 121 |
+
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
| 122 |
+
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
| 123 |
+
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
| 124 |
+
recursively specified by using mappings or sequences of contained specifications.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
An nn.Module containing the traced method.
|
| 128 |
+
|
| 129 |
+
"""
|
| 130 |
+
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
|
| 131 |
+
from torch.export.dynamic_shapes import _process_dynamic_shapes
|
| 132 |
+
|
| 133 |
+
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
|
| 134 |
+
|
| 135 |
+
assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
|
| 136 |
+
|
| 137 |
+
if kwargs is None:
|
| 138 |
+
kwargs = {}
|
| 139 |
+
|
| 140 |
+
constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
|
| 141 |
+
|
| 142 |
+
# Do not decompose dropout for exported models, because in eval mode the dropout
|
| 143 |
+
# op disappears from the graph, which makes it difficult to switch to train mode.
|
| 144 |
+
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
|
| 145 |
+
decomp_table = {
|
| 146 |
+
op: op.decompose
|
| 147 |
+
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
|
| 148 |
+
if op != torch.ops.aten.dropout.default
|
| 149 |
+
}
|
| 150 |
+
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
|
| 151 |
+
m = torch._dynamo.export(
|
| 152 |
+
f,
|
| 153 |
+
constraints=constraints,
|
| 154 |
+
assume_static_by_default=True,
|
| 155 |
+
tracing_mode="symbolic",
|
| 156 |
+
decomposition_table=decomp_table,
|
| 157 |
+
pre_dispatch=True,
|
| 158 |
+
aten_graph=True,
|
| 159 |
+
_log_export_usage=False,
|
| 160 |
+
)(
|
| 161 |
+
*args,
|
| 162 |
+
**kwargs,
|
| 163 |
+
)[0]
|
| 164 |
+
|
| 165 |
+
_, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
|
| 166 |
+
|
| 167 |
+
m.meta["inline_constraints"] = {
|
| 168 |
+
k: v
|
| 169 |
+
for k, v in fake_mode.shape_env.var_to_range.items()
|
| 170 |
+
if re.match(r"^[if]\d+$", str(k))
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if isinstance(f, torch.nn.Module):
|
| 174 |
+
from torch.export._trace import _restore_state_dict
|
| 175 |
+
_restore_state_dict(f, m)
|
| 176 |
+
|
| 177 |
+
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
|
| 178 |
+
range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
|
| 179 |
+
|
| 180 |
+
module = _create_stateful_graph_module(
|
| 181 |
+
m,
|
| 182 |
+
range_constraints=range_constraints,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
error_message = \
|
| 186 |
+
"""
|
| 187 |
+
Calling train() or eval() is not supported for exported models.
|
| 188 |
+
Alternatively, you may override these methods to do custom user behavior as follows:
|
| 189 |
+
|
| 190 |
+
def _my_train(self, mode: bool = True):
|
| 191 |
+
...
|
| 192 |
+
|
| 193 |
+
def _my_eval(self):
|
| 194 |
+
...
|
| 195 |
+
|
| 196 |
+
model.train = types.MethodType(_my_train, model)
|
| 197 |
+
model.eval = types.MethodType(_my_eval, model)
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def _train(self, mode: bool = True):
|
| 201 |
+
raise NotImplementedError(error_message)
|
| 202 |
+
|
| 203 |
+
def _eval(self, mode: bool = True):
|
| 204 |
+
raise NotImplementedError(error_message)
|
| 205 |
+
|
| 206 |
+
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
|
| 207 |
+
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
|
| 208 |
+
return module
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def save(
|
| 212 |
+
ep: ExportedProgram,
|
| 213 |
+
f: Union[str, os.PathLike, io.BytesIO],
|
| 214 |
+
*,
|
| 215 |
+
extra_files: Optional[Dict[str, Any]] = None,
|
| 216 |
+
opset_version: Optional[Dict[str, int]] = None,
|
| 217 |
+
) -> None:
|
| 218 |
+
if not isinstance(ep, ExportedProgram):
|
| 219 |
+
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
|
| 220 |
+
|
| 221 |
+
from .serde.serialize import serialize, SerializedArtifact
|
| 222 |
+
from .serde.schema import SCHEMA_VERSION
|
| 223 |
+
artifact: SerializedArtifact = serialize(ep, opset_version)
|
| 224 |
+
|
| 225 |
+
if isinstance(f, (str, os.PathLike)):
|
| 226 |
+
f = os.fspath(f)
|
| 227 |
+
|
| 228 |
+
with zipfile.ZipFile(f, 'w') as zipf:
|
| 229 |
+
# Save every field the SerializedArtifact to a file
|
| 230 |
+
assert isinstance(artifact.exported_program, bytes)
|
| 231 |
+
zipf.writestr("serialized_exported_program.json", artifact.exported_program)
|
| 232 |
+
zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
|
| 233 |
+
zipf.writestr("serialized_constants.pt", artifact.constants)
|
| 234 |
+
|
| 235 |
+
zipf.writestr('version', ".".join(map(str, SCHEMA_VERSION)))
|
| 236 |
+
|
| 237 |
+
# Add extra files if provided
|
| 238 |
+
if extra_files:
|
| 239 |
+
for extra_file_name, content in extra_files.items():
|
| 240 |
+
encoded_content = content.encode('utf-8')
|
| 241 |
+
zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def load(
|
| 245 |
+
f: Union[str, os.PathLike, io.BytesIO],
|
| 246 |
+
*,
|
| 247 |
+
extra_files: Optional[Dict[str, Any]] = None,
|
| 248 |
+
expected_opset_version: Optional[Dict[str, int]] = None,
|
| 249 |
+
) -> ExportedProgram:
|
| 250 |
+
if isinstance(f, (str, os.PathLike)):
|
| 251 |
+
f = os.fspath(f)
|
| 252 |
+
|
| 253 |
+
extra_files = extra_files or {}
|
| 254 |
+
|
| 255 |
+
with zipfile.ZipFile(f, 'r') as zipf:
|
| 256 |
+
# Check the version
|
| 257 |
+
version = zipf.read('version').decode().split('.')
|
| 258 |
+
from .serde.schema import SCHEMA_VERSION
|
| 259 |
+
|
| 260 |
+
assert len(version) == len(SCHEMA_VERSION)
|
| 261 |
+
if version[0] != str(SCHEMA_VERSION[0]):
|
| 262 |
+
raise RuntimeError(
|
| 263 |
+
f"Serialized version {version} does not match our current "
|
| 264 |
+
f"schema version {SCHEMA_VERSION}."
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
from .serde.serialize import deserialize, SerializedArtifact
|
| 268 |
+
|
| 269 |
+
# Load serialized_ep and serialized_state_dict from the zip file
|
| 270 |
+
|
| 271 |
+
serialized_exported_program: Optional[bytes] = None
|
| 272 |
+
serialized_state_dict: Optional[bytes] = None
|
| 273 |
+
serialized_constants: Optional[bytes] = None
|
| 274 |
+
|
| 275 |
+
for file_info in zipf.infolist():
|
| 276 |
+
file_content = zipf.read(file_info.filename)
|
| 277 |
+
|
| 278 |
+
if file_info.filename == "serialized_exported_program.json":
|
| 279 |
+
serialized_exported_program = file_content
|
| 280 |
+
elif file_info.filename == "serialized_state_dict.json":
|
| 281 |
+
warnings.warn("This version of file is deprecated")
|
| 282 |
+
serialized_state_dict = file_content
|
| 283 |
+
elif file_info.filename == "serialized_constants.json":
|
| 284 |
+
warnings.warn("This version of file is deprecated")
|
| 285 |
+
serialized_constants = file_content
|
| 286 |
+
elif file_info.filename == "serialized_state_dict.pt":
|
| 287 |
+
serialized_state_dict = file_content
|
| 288 |
+
elif file_info.filename == "serialized_constants.pt":
|
| 289 |
+
serialized_constants = file_content
|
| 290 |
+
elif file_info.filename.startswith("extra_files"):
|
| 291 |
+
filename = file_info.filename.split("/", 1)[1]
|
| 292 |
+
extra_files[filename] = file_content.decode('utf-8')
|
| 293 |
+
|
| 294 |
+
assert serialized_exported_program is not None
|
| 295 |
+
assert serialized_state_dict is not None
|
| 296 |
+
assert serialized_constants is not None
|
| 297 |
+
artifact: SerializedArtifact = SerializedArtifact(
|
| 298 |
+
serialized_exported_program,
|
| 299 |
+
serialized_state_dict,
|
| 300 |
+
serialized_constants,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Deserialize ExportedProgram
|
| 304 |
+
ep = deserialize(artifact, expected_opset_version)
|
| 305 |
+
|
| 306 |
+
return ep
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def aot_compile(
|
| 310 |
+
f: Callable,
|
| 311 |
+
args: Tuple[Any],
|
| 312 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 313 |
+
*,
|
| 314 |
+
dynamic_shapes: Optional[Dict[str, Any]] = None,
|
| 315 |
+
options: Optional[Dict[str, Any]] = None,
|
| 316 |
+
remove_runtime_assertions: bool = False,
|
| 317 |
+
disable_constraint_solver: bool = False,
|
| 318 |
+
) -> str:
|
| 319 |
+
"""
|
| 320 |
+
Note: this function is not stable yet
|
| 321 |
+
|
| 322 |
+
Traces either an nn.Module's forward function or just a callable with PyTorch
|
| 323 |
+
operations inside, generates executable cpp code from the program, and returns
|
| 324 |
+
the path to the generated shared library
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
f: the `nn.Module` or callable to trace.
|
| 328 |
+
|
| 329 |
+
args: example positional inputs.
|
| 330 |
+
|
| 331 |
+
kwargs: optional example keyword inputs.
|
| 332 |
+
|
| 333 |
+
dynamic_shapes: Should either be:
|
| 334 |
+
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
| 335 |
+
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
| 336 |
+
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
| 337 |
+
is defined in the original function signature.
|
| 338 |
+
|
| 339 |
+
The dynamic shape of a tensor argument can be specified as either
|
| 340 |
+
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
| 341 |
+
not required to include static dimension indices in this dict, but when they are,
|
| 342 |
+
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
| 343 |
+
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
| 344 |
+
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
| 345 |
+
recursively specified by using mappings or sequences of contained specifications.
|
| 346 |
+
|
| 347 |
+
options: A dictionary of options to control inductor
|
| 348 |
+
|
| 349 |
+
disable_constraint_solver: Whether the dim constraint solver must be disabled.
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
Path to the generated shared library
|
| 353 |
+
"""
|
| 354 |
+
from torch.export._trace import _export_to_torch_ir
|
| 355 |
+
from torch._inductor.decomposition import select_decomp_table
|
| 356 |
+
|
| 357 |
+
constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
|
| 358 |
+
|
| 359 |
+
if config.is_predispatch:
|
| 360 |
+
gm = torch.export._trace._export(f, args, kwargs, constraints, pre_dispatch=True).module()
|
| 361 |
+
else:
|
| 362 |
+
# We want to export to Torch IR here to utilize the pre_grad passes in
|
| 363 |
+
# inductor, which run on Torch IR.
|
| 364 |
+
gm = _export_to_torch_ir(
|
| 365 |
+
f,
|
| 366 |
+
args,
|
| 367 |
+
kwargs,
|
| 368 |
+
constraints,
|
| 369 |
+
disable_constraint_solver=disable_constraint_solver,
|
| 370 |
+
# Disabling this flag, because instead we can rely on the mapping
|
| 371 |
+
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
|
| 372 |
+
restore_fqn=False,
|
| 373 |
+
)
|
| 374 |
+
flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))
|
| 375 |
+
|
| 376 |
+
with torch.no_grad():
|
| 377 |
+
so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options) # type: ignore[arg-type]
|
| 378 |
+
|
| 379 |
+
return so_path
|
| 380 |
+
|
| 381 |
+
def aot_load(so_path: str, device: str) -> Callable:
|
| 382 |
+
"""
|
| 383 |
+
Loads a shared library generated by aot_compile and returns a callable
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
so_path: Path to the shared library
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
A callable
|
| 390 |
+
"""
|
| 391 |
+
if device == "cpu":
|
| 392 |
+
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
|
| 393 |
+
elif device == "cuda" or device.startswith("cuda:"):
|
| 394 |
+
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
|
| 395 |
+
else:
|
| 396 |
+
raise RuntimeError("Unsupported device " + device)
|
| 397 |
+
|
| 398 |
+
def optimized(*args, **kwargs):
|
| 399 |
+
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
|
| 400 |
+
in_spec = pytree.treespec_loads(call_spec[0])
|
| 401 |
+
out_spec = pytree.treespec_loads(call_spec[1])
|
| 402 |
+
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
|
| 403 |
+
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
|
| 404 |
+
return pytree.tree_unflatten(flat_outputs, out_spec)
|
| 405 |
+
|
| 406 |
+
return optimized
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import importlib
|
| 3 |
+
from os.path import basename, dirname, isfile, join
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._export.db.case import (
|
| 7 |
+
_EXAMPLE_CASES,
|
| 8 |
+
_EXAMPLE_CONFLICT_CASES,
|
| 9 |
+
_EXAMPLE_REWRITE_CASES,
|
| 10 |
+
SupportLevel,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
modules = glob.glob(join(dirname(__file__), "*.py"))
|
| 15 |
+
__all__ = [
|
| 16 |
+
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
# Import all module in the current directory.
|
| 20 |
+
from . import * # noqa: F403
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def all_examples():
|
| 24 |
+
return _EXAMPLE_CASES
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if len(_EXAMPLE_CONFLICT_CASES) > 0:
|
| 28 |
+
|
| 29 |
+
def get_name(case):
|
| 30 |
+
model = case.model
|
| 31 |
+
if isinstance(model, torch.nn.Module):
|
| 32 |
+
model = type(model)
|
| 33 |
+
return model.__name__
|
| 34 |
+
|
| 35 |
+
msg = "Error on conflict export case name.\n"
|
| 36 |
+
for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
|
| 37 |
+
msg += f"Case name {case_name} is associated with multiple cases:\n "
|
| 38 |
+
msg += f"[{','.join(map(get_name, cases))}]\n"
|
| 39 |
+
|
| 40 |
+
raise RuntimeError(msg)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def filter_examples_by_support_level(support_level: SupportLevel):
|
| 44 |
+
return {
|
| 45 |
+
key: val
|
| 46 |
+
for key, val in all_examples().items()
|
| 47 |
+
if val.support_level == support_level
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_rewrite_cases(case):
|
| 52 |
+
return _EXAMPLE_REWRITE_CASES.get(case.name, [])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc
ADDED
|
Binary file (1.91 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc
ADDED
|
Binary file (1.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc
ADDED
|
Binary file (1.73 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch._dynamo as torchdynamo
|
| 3 |
+
|
| 4 |
+
from torch._export.db.case import export_case
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.ones(3, 2), torch.tensor(4)),
|
| 9 |
+
tags={"torch.escape-hatch"},
|
| 10 |
+
)
|
| 11 |
+
class AssumeConstantResult(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Applying `assume_constant_result` decorator to burn make non-tracable code as constant.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
@torchdynamo.assume_constant_result
|
| 20 |
+
def get_item(self, y):
|
| 21 |
+
return y.int().item()
|
| 22 |
+
|
| 23 |
+
def forward(self, x, y):
|
| 24 |
+
return x[: self.get_item(y)]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 4),),
|
| 8 |
+
)
|
| 9 |
+
class ClassMethod(torch.nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
Class methods are inlined during tracing.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
@classmethod
|
| 15 |
+
def method(cls, x):
|
| 16 |
+
return x + 1
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.linear = torch.nn.Linear(4, 2)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
x = self.linear(x)
|
| 24 |
+
return self.method(x) * self.__class__.method(x) * type(self).method(x)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.ones(3),),
|
| 9 |
+
tags={
|
| 10 |
+
"torch.cond",
|
| 11 |
+
"torch.dynamic-shape",
|
| 12 |
+
},
|
| 13 |
+
)
|
| 14 |
+
class CondBranchNestedFunction(torch.nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 17 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 18 |
+
- both branches must return a single tensor
|
| 19 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 20 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 21 |
+
- branch function can not have closure variables
|
| 22 |
+
- no inplace mutations on inputs or global variables
|
| 23 |
+
|
| 24 |
+
This example demonstrates using nested function in cond().
|
| 25 |
+
|
| 26 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
def true_fn(x):
|
| 33 |
+
def inner_true_fn(y):
|
| 34 |
+
return x + y
|
| 35 |
+
|
| 36 |
+
return inner_true_fn(x)
|
| 37 |
+
|
| 38 |
+
def false_fn(x):
|
| 39 |
+
def inner_false_fn(y):
|
| 40 |
+
return x - y
|
| 41 |
+
|
| 42 |
+
return inner_false_fn(x)
|
| 43 |
+
|
| 44 |
+
return cond(x.shape[0] < 10, true_fn, false_fn, [x])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.tensor(True), torch.ones(3, 2)),
|
| 9 |
+
tags={"torch.cond", "python.closure"},
|
| 10 |
+
)
|
| 11 |
+
class CondClosedOverVariable(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
torch.cond() supports branches closed over arbitrary variables.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def forward(self, pred, x):
|
| 17 |
+
def true_fn(val):
|
| 18 |
+
return x * 2
|
| 19 |
+
|
| 20 |
+
def false_fn(val):
|
| 21 |
+
return x - 2
|
| 22 |
+
|
| 23 |
+
return cond(pred, true_fn, false_fn, [x + 1])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from torch.export import Dim
|
| 5 |
+
from functorch.experimental.control_flow import cond
|
| 6 |
+
|
| 7 |
+
x = torch.randn(3, 2)
|
| 8 |
+
y = torch.ones(2)
|
| 9 |
+
dim0_x = Dim("dim0_x")
|
| 10 |
+
|
| 11 |
+
@export_case(
|
| 12 |
+
example_inputs=(x, y),
|
| 13 |
+
tags={
|
| 14 |
+
"torch.cond",
|
| 15 |
+
"torch.dynamic-shape",
|
| 16 |
+
},
|
| 17 |
+
extra_inputs=(torch.randn(2, 2), torch.ones(2)),
|
| 18 |
+
dynamic_shapes={"x": {0: dim0_x}, "y": None},
|
| 19 |
+
)
|
| 20 |
+
class CondOperands(torch.nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
The operands passed to cond() must be:
|
| 23 |
+
- a list of tensors
|
| 24 |
+
- match arguments of `true_fn` and `false_fn`
|
| 25 |
+
|
| 26 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
def forward(self, x, y):
|
| 33 |
+
def true_fn(x, y):
|
| 34 |
+
return x + y
|
| 35 |
+
|
| 36 |
+
def false_fn(x, y):
|
| 37 |
+
return x - y
|
| 38 |
+
|
| 39 |
+
return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.ones(6, 4, 3),),
|
| 9 |
+
tags={
|
| 10 |
+
"torch.cond",
|
| 11 |
+
"torch.dynamic-shape",
|
| 12 |
+
},
|
| 13 |
+
)
|
| 14 |
+
class CondPredicate(torch.nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
The conditional statement (aka predicate) passed to cond() must be one of the following:
|
| 17 |
+
- torch.Tensor with a single element
|
| 18 |
+
- boolean expression
|
| 19 |
+
|
| 20 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
pred = x.dim() > 2 and x.shape[2] > 10
|
| 28 |
+
|
| 29 |
+
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.tensor(4),),
|
| 8 |
+
tags={
|
| 9 |
+
"torch.dynamic-value",
|
| 10 |
+
"torch.escape-hatch",
|
| 11 |
+
},
|
| 12 |
+
)
|
| 13 |
+
class ConstrainAsSizeExample(torch.nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
If the value is not known at tracing time, you can provide hint so that we
|
| 16 |
+
can trace further. Please look at constrain_as_value and constrain_as_size APIs
|
| 17 |
+
constrain_as_size is used for values that NEED to be used for constructing
|
| 18 |
+
tensor.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
a = x.item()
|
| 26 |
+
torch._constrain_as_size(a, min=0, max=5)
|
| 27 |
+
return torch.ones((a, 5))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.tensor(4), torch.randn(5, 5)),
|
| 8 |
+
tags={
|
| 9 |
+
"torch.dynamic-value",
|
| 10 |
+
"torch.escape-hatch",
|
| 11 |
+
},
|
| 12 |
+
)
|
| 13 |
+
class ConstrainAsValueExample(torch.nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
If the value is not known at tracing time, you can provide hint so that we
|
| 16 |
+
can trace further. Please look at constrain_as_value and constrain_as_size APIs.
|
| 17 |
+
constrain_as_value is used for values that don't need to be used for constructing
|
| 18 |
+
tensor.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
def forward(self, x, y):
|
| 25 |
+
a = x.item()
|
| 26 |
+
torch._constrain_as_value(a, min=0, max=5)
|
| 27 |
+
|
| 28 |
+
if a < 6:
|
| 29 |
+
return y.sin()
|
| 30 |
+
return y.cos()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2), torch.tensor(4)),
|
| 8 |
+
tags={"python.data-structure"},
|
| 9 |
+
)
|
| 10 |
+
class Dictionary(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Dictionary structures are inlined and flattened along tracing.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
def forward(self, x, y):
|
| 18 |
+
elements = {}
|
| 19 |
+
elements["x2"] = x * x
|
| 20 |
+
y = y * elements["x2"]
|
| 21 |
+
return {"y": y}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2, 2),),
|
| 8 |
+
tags={"torch.dynamic-shape", "python.control-flow"},
|
| 9 |
+
)
|
| 10 |
+
class DynamicShapeIfGuard(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
`if` statement with backed dynamic shape predicate will be specialized into
|
| 13 |
+
one particular branch and generate a guard. However, export will fail if the
|
| 14 |
+
the dimension is marked as dynamic shape from higher level API.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
if x.shape[0] == 3:
|
| 19 |
+
return x.cos()
|
| 20 |
+
|
| 21 |
+
return x.sin()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"torch.dynamic-shape"},
|
| 9 |
+
)
|
| 10 |
+
class DynamicShapeSlicing(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Slices with dynamic shape arguments should be captured into the graph
|
| 13 |
+
rather than being baked in.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(10, 10),),
|
| 8 |
+
tags={"torch.dynamic-shape"},
|
| 9 |
+
)
|
| 10 |
+
class DynamicShapeView(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Dynamic shapes should be propagated to view arguments instead of being
|
| 13 |
+
baked into the exported graph.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
new_x_shape = x.size()[:-1] + (2, 5)
|
| 21 |
+
x = x.view(*new_x_shape)
|
| 22 |
+
return x.permute(0, 2, 1)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"torch.dynamic-shape", "python.data-structure", "python.assert"},
|
| 9 |
+
)
|
| 10 |
+
class ListContains(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
List containment relation can be checked on a dynamic shape or constants.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
assert x.size(-1) in [6, 2]
|
| 19 |
+
assert x.size(0) not in [4, 5, 6]
|
| 20 |
+
assert "monkey" not in ["cow", "pig"]
|
| 21 |
+
return x + x
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch._export.db.case import export_case
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@export_case(
|
| 9 |
+
example_inputs=([torch.ones(3, 2), torch.tensor(4), torch.tensor(5)],),
|
| 10 |
+
tags={"python.control-flow", "python.data-structure"},
|
| 11 |
+
)
|
| 12 |
+
class ListUnpack(torch.nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Lists are treated as static construct, therefore unpacking should be
|
| 15 |
+
erased after tracing.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
super().__init__()
|
| 20 |
+
|
| 21 |
+
def forward(self, args: List[torch.Tensor]):
|
| 22 |
+
"""
|
| 23 |
+
Lists are treated as static construct, therefore unpacking should be
|
| 24 |
+
erased after tracing.
|
| 25 |
+
"""
|
| 26 |
+
x, *y = args
|
| 27 |
+
return x + y[0]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"python.object-model"},
|
| 9 |
+
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
| 10 |
+
)
|
| 11 |
+
class ModelAttrMutation(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Attribute mutation is not supported.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)]
|
| 19 |
+
|
| 20 |
+
def recreate_list(self):
|
| 21 |
+
return [torch.zeros(3, 2), torch.zeros(3, 2)]
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
self.attr_list = self.recreate_list()
|
| 25 |
+
return x.sum() + self.attr_list[0].sum()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
from torch.utils import _pytree as pytree
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
|
| 9 |
+
support_level=SupportLevel.SUPPORTED,
|
| 10 |
+
)
|
| 11 |
+
class PytreeFlatten(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Pytree from PyTorch can be captured by TorchDynamo.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
y, spec = pytree.tree_flatten(x)
|
| 20 |
+
return y[0] + 1
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from torch.export import Dim
|
| 5 |
+
|
| 6 |
+
x = torch.ones(3, 2)
|
| 7 |
+
dim1_x = Dim("dim1_x")
|
| 8 |
+
|
| 9 |
+
@export_case(
|
| 10 |
+
example_inputs=(x,),
|
| 11 |
+
tags={"torch.dynamic-shape"},
|
| 12 |
+
dynamic_shapes={"x": {1: dim1_x}},
|
| 13 |
+
)
|
| 14 |
+
class ScalarOutput(torch.nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
Returning scalar values from the graph is supported, in addition to Tensor
|
| 17 |
+
outputs. Symbolic shapes are captured and rank is specialized.
|
| 18 |
+
"""
|
| 19 |
+
def __init__(self):
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return x.shape[1] + 1
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch._export.db.case import export_case
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Animal(Enum):
|
| 9 |
+
COW = "moo"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@export_case(
|
| 13 |
+
example_inputs=(torch.ones(3, 2),),
|
| 14 |
+
)
|
| 15 |
+
class SpecializedAttribute(torch.nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Model attributes are specialized.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.a = "moo"
|
| 23 |
+
self.b = 4
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
if self.a == Animal.COW.value:
|
| 27 |
+
return x * x + self.b
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError("bad")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"python.control-flow"},
|
| 9 |
+
)
|
| 10 |
+
class StaticForLoop(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
A for loop with constant number of iterations should be unrolled in the exported graph.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
ret = []
|
| 20 |
+
for i in range(10): # constant
|
| 21 |
+
ret.append(i + x)
|
| 22 |
+
return ret
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2, 2),),
|
| 8 |
+
tags={"python.control-flow"},
|
| 9 |
+
)
|
| 10 |
+
class StaticIf(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
`if` statement with static predicate value should be traced through with the
|
| 13 |
+
taken branch.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
if len(x.shape) == 3:
|
| 21 |
+
return x + torch.ones(1, 1, 1)
|
| 22 |
+
|
| 23 |
+
return x
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.randn(3, 2), "attr"),
|
| 8 |
+
tags={"python.builtin"},
|
| 9 |
+
support_level=SupportLevel.SUPPORTED,
|
| 10 |
+
)
|
| 11 |
+
class TensorSetattr(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
setattr() call onto tensors is not supported.
|
| 14 |
+
"""
|
| 15 |
+
def forward(self, x, attr):
|
| 16 |
+
setattr(x, attr, torch.randn(3, 2))
|
| 17 |
+
return x + 4
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel, export_rewrite_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class A:
|
| 7 |
+
@classmethod
|
| 8 |
+
def func(cls, x):
|
| 9 |
+
return 1 + x
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@export_case(
|
| 13 |
+
example_inputs=(torch.ones(3, 4),),
|
| 14 |
+
tags={"python.builtin"},
|
| 15 |
+
support_level=SupportLevel.SUPPORTED,
|
| 16 |
+
)
|
| 17 |
+
class TypeReflectionMethod(torch.nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
type() calls on custom objects followed by attribute accesses are not allowed
|
| 20 |
+
due to its overly dynamic nature.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
a = A()
|
| 28 |
+
return type(a).func(x)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@export_rewrite_case(parent=TypeReflectionMethod)
|
| 32 |
+
class TypeReflectionMethodRewrite(torch.nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Custom object class methods will be inlined.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return A.func(x)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class _RemoveRuntimeAssertionsPass(PassBase):
|
| 6 |
+
"""
|
| 7 |
+
Remove runtime assertions inserted by the
|
| 8 |
+
_AddRuntimeAssertionsForInlineConstraintsPass.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def call(self, graph_module) -> PassResult:
|
| 12 |
+
modified = False
|
| 13 |
+
for module in graph_module.modules():
|
| 14 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 15 |
+
continue
|
| 16 |
+
for node in module.graph.nodes:
|
| 17 |
+
if node.target == torch.ops.aten._assert_async.msg:
|
| 18 |
+
assert_async_node = node
|
| 19 |
+
if len(assert_async_node.users) > 0:
|
| 20 |
+
continue
|
| 21 |
+
module.graph.erase_node(assert_async_node)
|
| 22 |
+
# the upstream scalar_tensor <- {le, ge} <- sym_size
|
| 23 |
+
# linear chain of nodes of nodes is removed by the
|
| 24 |
+
# downstream dead code elimination
|
| 25 |
+
modified = True
|
| 26 |
+
return PassResult(graph_module, modified)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
|
| 3 |
+
|
| 4 |
+
from ..utils import (
|
| 5 |
+
node_inline_,
|
| 6 |
+
node_replace_,
|
| 7 |
+
nodes_filter,
|
| 8 |
+
nodes_first,
|
| 9 |
+
nodes_map,
|
| 10 |
+
sequential_split,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _is_set_grad_enabled_node(node: torch.fx.Node):
|
| 15 |
+
return (
|
| 16 |
+
node
|
| 17 |
+
and node.op == "call_function"
|
| 18 |
+
and node.target == torch._C._set_grad_enabled
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
|
| 23 |
+
if node.op == "call_module":
|
| 24 |
+
assert isinstance(node.target, str)
|
| 25 |
+
subgm = getattr(node.graph.owning_module, node.target)
|
| 26 |
+
first_non_ph = nodes_first(
|
| 27 |
+
subgm.graph.nodes, lambda node: node.op != "placeholder"
|
| 28 |
+
)
|
| 29 |
+
if (
|
| 30 |
+
first_non_ph
|
| 31 |
+
and first_non_ph.op == "call_function"
|
| 32 |
+
and first_non_ph.target == torch._C._set_grad_enabled
|
| 33 |
+
):
|
| 34 |
+
return (
|
| 35 |
+
first_non_ph.args[0] != torch.is_grad_enabled()
|
| 36 |
+
if omit_if_same_with_ambient
|
| 37 |
+
else True
|
| 38 |
+
)
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _replace_with_hop(node: torch.fx.Node):
|
| 43 |
+
assert node.op == "call_module"
|
| 44 |
+
graph: torch.fx.Graph = node.graph
|
| 45 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 46 |
+
assert isinstance(node.target, str)
|
| 47 |
+
sub_gm = getattr(gm, node.target)
|
| 48 |
+
sub_graph = sub_gm.graph
|
| 49 |
+
set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
|
| 50 |
+
if len(set_grad_nodes) > 0:
|
| 51 |
+
assert len(set_grad_nodes) == 1
|
| 52 |
+
set_grad_node = set_grad_nodes[0]
|
| 53 |
+
enable_grad_val = set_grad_node.args[0]
|
| 54 |
+
with graph.inserting_before(node):
|
| 55 |
+
get_attr_node = graph.get_attr(node.target)
|
| 56 |
+
output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
|
| 57 |
+
if output_node is not None:
|
| 58 |
+
assert len(output_node.args) == 1
|
| 59 |
+
output_args = output_node.args[0]
|
| 60 |
+
if isinstance(output_args, (tuple, list)):
|
| 61 |
+
call_func_node = graph.call_function(
|
| 62 |
+
wrap_with_set_grad_enabled,
|
| 63 |
+
(enable_grad_val, get_attr_node, *node.args),
|
| 64 |
+
{},
|
| 65 |
+
)
|
| 66 |
+
# Create the metadata
|
| 67 |
+
call_func_node.meta["val"] = tuple(
|
| 68 |
+
arg.meta["val"] for arg in output_args
|
| 69 |
+
)
|
| 70 |
+
node_replace_(node, call_func_node, delete_old=True)
|
| 71 |
+
|
| 72 |
+
# Rename the name of getitem nodes to the actual name of its contents
|
| 73 |
+
# for passing verifier and better readability, also propagate metadata
|
| 74 |
+
for get_item_node in call_func_node.users.keys():
|
| 75 |
+
idx: int = get_item_node.args[1]
|
| 76 |
+
output_node = output_args[idx]
|
| 77 |
+
get_item_node._rename(output_node.name)
|
| 78 |
+
get_item_node.meta = output_node.meta
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
elif isinstance(output_args, torch.fx.Node):
|
| 82 |
+
call_func_node = graph.create_node(
|
| 83 |
+
"call_function",
|
| 84 |
+
wrap_with_set_grad_enabled,
|
| 85 |
+
(enable_grad_val, get_attr_node, *node.args),
|
| 86 |
+
{},
|
| 87 |
+
output_args.name,
|
| 88 |
+
)
|
| 89 |
+
call_func_node.meta = output_args.meta
|
| 90 |
+
node_replace_(node, call_func_node, delete_old=True)
|
| 91 |
+
else:
|
| 92 |
+
raise NotImplementedError(
|
| 93 |
+
f"repalce_set_grad_with_hop_pass doesnt' support output type {type(output_args)}"
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
raise NotImplementedError(
|
| 97 |
+
"Cannot replace a call_module with a hop if it has no output. This module will gets DCEed."
|
| 98 |
+
)
|
| 99 |
+
sub_graph.erase_node(set_grad_node)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _remove_set_grad_and_inline(node: torch.fx.Node):
|
| 103 |
+
assert node.op == "call_module"
|
| 104 |
+
graph: torch.fx.Graph = node.graph
|
| 105 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 106 |
+
assert isinstance(node.target, str)
|
| 107 |
+
sub_gm = getattr(gm, node.target)
|
| 108 |
+
sub_graph = sub_gm.graph
|
| 109 |
+
nodes_map(
|
| 110 |
+
sub_graph.nodes,
|
| 111 |
+
lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
|
| 112 |
+
)
|
| 113 |
+
node_inline_(node)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
|
| 117 |
+
# If there is no set_grad_enabled node, return the original graph module
|
| 118 |
+
need_replacing = False
|
| 119 |
+
for node in gm.graph.nodes:
|
| 120 |
+
if _is_set_grad_enabled_node(node):
|
| 121 |
+
need_replacing = True
|
| 122 |
+
|
| 123 |
+
if not need_replacing:
|
| 124 |
+
return gm
|
| 125 |
+
|
| 126 |
+
new_gm = sequential_split(gm, _is_set_grad_enabled_node)
|
| 127 |
+
|
| 128 |
+
def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
|
| 129 |
+
if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
|
| 130 |
+
_replace_with_hop(node)
|
| 131 |
+
else:
|
| 132 |
+
_remove_set_grad_and_inline(node)
|
| 133 |
+
|
| 134 |
+
nodes_map(
|
| 135 |
+
list(new_gm.graph.nodes),
|
| 136 |
+
lambda node: _maybe_inline_or_replace_with_hop(node)
|
| 137 |
+
if node.op == "call_module"
|
| 138 |
+
else node,
|
| 139 |
+
)
|
| 140 |
+
new_gm.graph.lint()
|
| 141 |
+
return new_gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import math
|
| 3 |
+
import operator
|
| 4 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 8 |
+
|
| 9 |
+
from torch.export import ExportedProgram
|
| 10 |
+
from torch.utils._pytree import (
|
| 11 |
+
_register_pytree_node,
|
| 12 |
+
Context,
|
| 13 |
+
FlattenFunc,
|
| 14 |
+
FromDumpableContextFn,
|
| 15 |
+
KeyPath,
|
| 16 |
+
keystr,
|
| 17 |
+
MappingKey,
|
| 18 |
+
SequenceKey,
|
| 19 |
+
ToDumpableContextFn,
|
| 20 |
+
UnflattenFunc,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _check_input_constraints_for_graph(
|
| 25 |
+
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
|
| 26 |
+
):
|
| 27 |
+
def get_keystr(key_path: KeyPath) -> str:
|
| 28 |
+
"""For a given index into the flat_args, return a human readable string
|
| 29 |
+
describing how to access it, e.g. "*args["foo"][0].bar"
|
| 30 |
+
"""
|
| 31 |
+
# Prefix the keypath with "*args" or "**kwargs" to make it clearer where
|
| 32 |
+
# the arguments come from. Ultimately we ought to serialize the
|
| 33 |
+
# original arg names for the best error message here.
|
| 34 |
+
args_kwargs_key_path = key_path[0]
|
| 35 |
+
assert isinstance(args_kwargs_key_path, SequenceKey)
|
| 36 |
+
if args_kwargs_key_path.idx == 0:
|
| 37 |
+
return f"*args{keystr(key_path[1:])}"
|
| 38 |
+
else:
|
| 39 |
+
kwarg_key = key_path[1]
|
| 40 |
+
assert isinstance(kwarg_key, MappingKey)
|
| 41 |
+
name = str(kwarg_key)[1:-1] # get rid of the enclosed []
|
| 42 |
+
return f"{name}{keystr(key_path[2:])}"
|
| 43 |
+
|
| 44 |
+
import sympy
|
| 45 |
+
|
| 46 |
+
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
|
| 47 |
+
_convert_range_to_int,
|
| 48 |
+
)
|
| 49 |
+
from torch.utils._sympy.solve import try_solve
|
| 50 |
+
|
| 51 |
+
if len(flat_args_with_path) != len(input_placeholders):
|
| 52 |
+
raise RuntimeError(
|
| 53 |
+
"Unexpected number of inputs "
|
| 54 |
+
f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
|
| 55 |
+
)
|
| 56 |
+
# NOTE: export already guarantees that the same symbol is used in metadata
|
| 57 |
+
# for all InputDims related by equality constraints, so we can just unify
|
| 58 |
+
# symbols with given input dimension values to check equality constraints.
|
| 59 |
+
unification_map: "Dict[sympy.Symbol, Any]" = {}
|
| 60 |
+
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
|
| 61 |
+
node_val = node.meta.get("val")
|
| 62 |
+
if isinstance(node_val, FakeTensor):
|
| 63 |
+
if not isinstance(arg, torch.Tensor):
|
| 64 |
+
raise RuntimeError(
|
| 65 |
+
f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
if len(node_val.shape) != len(arg.shape):
|
| 69 |
+
raise RuntimeError(
|
| 70 |
+
f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
|
| 71 |
+
f"(expected {node_val.shape}, got {arg.shape})"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
|
| 75 |
+
# TODO(avik): Assert the following property in the IR verifier:
|
| 76 |
+
# node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
|
| 77 |
+
if (
|
| 78 |
+
isinstance(node_dim, torch.SymInt)
|
| 79 |
+
and len(node_dim.node.expr.free_symbols) == 1
|
| 80 |
+
):
|
| 81 |
+
symbol = next(iter(node_dim.node.expr.free_symbols))
|
| 82 |
+
if symbol in unification_map:
|
| 83 |
+
existing_dim = node_dim.node.expr.subs(unification_map)
|
| 84 |
+
if arg_dim != existing_dim:
|
| 85 |
+
raise RuntimeError(
|
| 86 |
+
f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
|
| 87 |
+
f"{existing_dim}, but got {arg_dim}",
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
if (
|
| 91 |
+
isinstance(arg_dim, torch.SymInt)
|
| 92 |
+
and not arg_dim.node.expr.is_number
|
| 93 |
+
):
|
| 94 |
+
# This can happen when, say, arg is a fake tensor.
|
| 95 |
+
# We do not run checks on symbolic shapes of fake inputs as
|
| 96 |
+
# such checks can affect the shape env.
|
| 97 |
+
pass
|
| 98 |
+
else:
|
| 99 |
+
solution = try_solve(
|
| 100 |
+
sympy.Eq(node_dim.node.expr, arg_dim), symbol
|
| 101 |
+
)
|
| 102 |
+
if solution is None:
|
| 103 |
+
raise RuntimeError( # noqa: TRY200
|
| 104 |
+
f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
|
| 105 |
+
f"of the form {node_dim.node.expr}, where {symbol} is an integer"
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
unification_map[symbol] = int(solution[1])
|
| 109 |
+
|
| 110 |
+
if node_dim.node.expr in range_constraints:
|
| 111 |
+
min_val, max_val = _convert_range_to_int(
|
| 112 |
+
range_constraints[node_dim.node.expr]
|
| 113 |
+
)
|
| 114 |
+
# NOTE: we allow dimensions to be 0/1 at runtime
|
| 115 |
+
if min_val > 2:
|
| 116 |
+
if arg_dim < min_val:
|
| 117 |
+
raise RuntimeError(
|
| 118 |
+
f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
|
| 119 |
+
f"{min_val}, but got {arg_dim}",
|
| 120 |
+
)
|
| 121 |
+
if max_val < math.inf:
|
| 122 |
+
if arg_dim > max_val:
|
| 123 |
+
raise RuntimeError(
|
| 124 |
+
f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
|
| 125 |
+
f"{max_val}, but got {arg_dim}",
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
if arg_dim != node_dim:
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
|
| 131 |
+
f"{node_dim}, but got {arg_dim}",
|
| 132 |
+
)
|
| 133 |
+
elif isinstance(node_val, (int, float, str)):
|
| 134 |
+
if type(arg) != type(node_val) or arg != node_val:
|
| 135 |
+
raise RuntimeError(
|
| 136 |
+
f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def register_dataclass_as_pytree_node(
|
| 141 |
+
cls: Type[Any],
|
| 142 |
+
flatten_fn: Optional[FlattenFunc] = None,
|
| 143 |
+
unflatten_fn: Optional[UnflattenFunc] = None,
|
| 144 |
+
*,
|
| 145 |
+
serialized_type_name: Optional[str] = None,
|
| 146 |
+
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
| 147 |
+
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
| 148 |
+
return_none_fields: bool = False,
|
| 149 |
+
) -> None:
|
| 150 |
+
assert dataclasses.is_dataclass(
|
| 151 |
+
cls
|
| 152 |
+
), f"Only dataclasses can be registered with this function: {cls}"
|
| 153 |
+
|
| 154 |
+
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
| 155 |
+
flattened = []
|
| 156 |
+
flat_names = []
|
| 157 |
+
none_names = []
|
| 158 |
+
for f in dataclasses.fields(obj):
|
| 159 |
+
name, val = f.name, getattr(obj, f.name)
|
| 160 |
+
if val is not None or return_none_fields:
|
| 161 |
+
flattened.append(val)
|
| 162 |
+
flat_names.append(name)
|
| 163 |
+
else:
|
| 164 |
+
none_names.append(name)
|
| 165 |
+
return flattened, [flat_names, none_names]
|
| 166 |
+
|
| 167 |
+
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
|
| 168 |
+
flat_names, none_names = context
|
| 169 |
+
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
|
| 170 |
+
|
| 171 |
+
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
|
| 172 |
+
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
|
| 173 |
+
|
| 174 |
+
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
| 175 |
+
raise ValueError(
|
| 176 |
+
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
| 177 |
+
"be None or registered."
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
_register_pytree_node(
|
| 181 |
+
cls,
|
| 182 |
+
flatten_fn,
|
| 183 |
+
unflatten_fn,
|
| 184 |
+
serialized_type_name=serialized_type_name,
|
| 185 |
+
to_dumpable_context=to_dumpable_context,
|
| 186 |
+
from_dumpable_context=from_dumpable_context,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
|
| 191 |
+
"""
|
| 192 |
+
Checks if the given node is a parameter within the exported program
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
return node.name in program.graph_signature.inputs_to_parameters
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def get_param(
|
| 199 |
+
program: ExportedProgram,
|
| 200 |
+
node: torch.fx.Node,
|
| 201 |
+
) -> Optional[torch.nn.Parameter]:
|
| 202 |
+
"""
|
| 203 |
+
Returns the parameter associated with the given node in the exported program.
|
| 204 |
+
Returns None if the node is not a parameter within the exported program
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
if is_param(program, node):
|
| 208 |
+
parameter_name = program.graph_signature.inputs_to_parameters[node.name]
|
| 209 |
+
return program.state_dict[parameter_name]
|
| 210 |
+
|
| 211 |
+
return None
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
|
| 215 |
+
"""
|
| 216 |
+
Checks if the given node is a buffer within the exported program
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
return node.name in program.graph_signature.inputs_to_buffers
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_buffer(
|
| 223 |
+
program: ExportedProgram,
|
| 224 |
+
node: torch.fx.Node,
|
| 225 |
+
) -> Optional[torch.Tensor]:
|
| 226 |
+
"""
|
| 227 |
+
Returns the buffer associated with the given node in the exported program.
|
| 228 |
+
Returns None if the node is not a buffer within the exported program
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
if is_buffer(program, node):
|
| 232 |
+
buffer_name = program.graph_signature.inputs_to_buffers[node.name]
|
| 233 |
+
if buffer_name in program.graph_signature.non_persistent_buffers:
|
| 234 |
+
return program.constants[buffer_name]
|
| 235 |
+
else:
|
| 236 |
+
return program.state_dict[buffer_name]
|
| 237 |
+
|
| 238 |
+
return None
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def is_lifted_tensor_constant(
|
| 242 |
+
program: ExportedProgram,
|
| 243 |
+
node: torch.fx.Node,
|
| 244 |
+
) -> bool:
|
| 245 |
+
"""
|
| 246 |
+
Checks if the given node is a lifted tensor constant within the exported program
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_lifted_tensor_constant(
|
| 253 |
+
program: ExportedProgram,
|
| 254 |
+
node: torch.fx.Node,
|
| 255 |
+
) -> Optional[torch.Tensor]:
|
| 256 |
+
"""
|
| 257 |
+
Returns the lifted tensor constant associated with the given node in the exported program.
|
| 258 |
+
Returns None if the node is not a lifted tensor constant within the exported program
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
if is_lifted_tensor_constant(program, node):
|
| 262 |
+
lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
|
| 263 |
+
node.name
|
| 264 |
+
]
|
| 265 |
+
return program.constants[lifted_tensor_name]
|
| 266 |
+
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
|
| 271 |
+
"""
|
| 272 |
+
Splits the graph module into multiple submodules based on the node_call_back.
|
| 273 |
+
The node_call_back should return True if the node is a delimiter. Delimiter will be
|
| 274 |
+
the first node in the next submodule.
|
| 275 |
+
"""
|
| 276 |
+
from torch.fx.passes.split_module import split_module
|
| 277 |
+
|
| 278 |
+
split_map = {}
|
| 279 |
+
split_id = 0
|
| 280 |
+
for node in gm.graph.nodes:
|
| 281 |
+
if node_call_back(node):
|
| 282 |
+
split_id += 1
|
| 283 |
+
split_map[node] = split_id
|
| 284 |
+
|
| 285 |
+
new_gm = split_module(
|
| 286 |
+
gm,
|
| 287 |
+
gm,
|
| 288 |
+
lambda node: split_map[node],
|
| 289 |
+
keep_original_order=True,
|
| 290 |
+
keep_original_node_name=True,
|
| 291 |
+
)
|
| 292 |
+
# Keep the codegen from original graph module to preserve e.g. pytree info.
|
| 293 |
+
new_gm.graph._codegen = gm.graph._codegen
|
| 294 |
+
new_gm.recompile()
|
| 295 |
+
return new_gm
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
|
| 299 |
+
"""Returns the nodes that match the node_call_back as a list."""
|
| 300 |
+
return [node for node in nodes if node_call_back(node)]
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def nodes_first(
|
| 304 |
+
nodes: List[torch.fx.Node], node_call_back=None
|
| 305 |
+
) -> Optional[torch.fx.Node]:
|
| 306 |
+
"""
|
| 307 |
+
Returns the first node that matches the node_call_back. If no node matches, returns None.
|
| 308 |
+
When node_call_back is None, returns the first node in the node list.
|
| 309 |
+
"""
|
| 310 |
+
ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
|
| 311 |
+
if len(ret) > 0:
|
| 312 |
+
return ret[0]
|
| 313 |
+
return None
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
|
| 317 |
+
"""Returns the number of nodes that match the node_call_back."""
|
| 318 |
+
return len(nodes_filter(nodes, node_call_back))
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
|
| 322 |
+
"""
|
| 323 |
+
Sequentially visit the nodes list and invoke node_call_back on each element.
|
| 324 |
+
Returns the nodes list after the node_call_back is invoked on each element.
|
| 325 |
+
"""
|
| 326 |
+
for node in nodes:
|
| 327 |
+
node_call_back(node)
|
| 328 |
+
return nodes
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def node_replace_(
|
| 332 |
+
old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False
|
| 333 |
+
) -> None:
|
| 334 |
+
"""
|
| 335 |
+
Replace all uses of old_node with new_node.
|
| 336 |
+
"""
|
| 337 |
+
old_node.replace_all_uses_with(new_node)
|
| 338 |
+
if delete_old:
|
| 339 |
+
old_node.users.clear()
|
| 340 |
+
old_node.graph.erase_node(old_node)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def node_inline_(call_mod_node: torch.fx.Node) -> None:
|
| 344 |
+
"""
|
| 345 |
+
Inline the submodule of the given node into the parent module.
|
| 346 |
+
Note: we only support the case where submodule takes tensors inputs.
|
| 347 |
+
"""
|
| 348 |
+
assert call_mod_node.op == "call_module"
|
| 349 |
+
gm = call_mod_node.graph.owning_module
|
| 350 |
+
|
| 351 |
+
assert isinstance(call_mod_node.target, str)
|
| 352 |
+
sub_gm = getattr(gm, call_mod_node.target)
|
| 353 |
+
|
| 354 |
+
phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
|
| 355 |
+
body = (
|
| 356 |
+
node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
|
| 357 |
+
)
|
| 358 |
+
output = [node for node in sub_gm.graph.nodes if node.op == "output"]
|
| 359 |
+
|
| 360 |
+
for ph, arg in zip(phs, call_mod_node.args):
|
| 361 |
+
assert isinstance(arg, torch.fx.Node)
|
| 362 |
+
node_replace_(ph, arg, delete_old=True)
|
| 363 |
+
|
| 364 |
+
with gm.graph.inserting_before(call_mod_node):
|
| 365 |
+
for node in body:
|
| 366 |
+
new_node = gm.graph.node_copy(node)
|
| 367 |
+
node_replace_(node, new_node, delete_old=True)
|
| 368 |
+
|
| 369 |
+
if len(output) > 0:
|
| 370 |
+
assert len(output) == 1 and len(output[0].args) == 1
|
| 371 |
+
new_output = output[0].args[0]
|
| 372 |
+
|
| 373 |
+
if isinstance(new_output, torch.fx.Node):
|
| 374 |
+
node_replace_(call_mod_node, new_output, delete_old=True)
|
| 375 |
+
elif isinstance(new_output, (list, tuple)):
|
| 376 |
+
# Inline the get_item calls for the output node.
|
| 377 |
+
get_item_users = nodes_filter(
|
| 378 |
+
list(call_mod_node.users.keys()),
|
| 379 |
+
lambda node: node.op == "call_function"
|
| 380 |
+
and node.target == operator.getitem,
|
| 381 |
+
)
|
| 382 |
+
# get_item_node.args[1] is the idx referring to new_output[idx]
|
| 383 |
+
nodes_map(
|
| 384 |
+
get_item_users,
|
| 385 |
+
lambda get_item_node: node_replace_(
|
| 386 |
+
get_item_node,
|
| 387 |
+
new_output[get_item_node.args[1]],
|
| 388 |
+
delete_old=True,
|
| 389 |
+
),
|
| 390 |
+
)
|
| 391 |
+
call_mod_node.graph.erase_node(call_mod_node)
|
| 392 |
+
else:
|
| 393 |
+
raise NotImplementedError(
|
| 394 |
+
f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
|
| 395 |
+
)
|
| 396 |
+
else:
|
| 397 |
+
call_mod_node.graph.erase_node(call_mod_node)
|
| 398 |
+
|
| 399 |
+
gm.delete_all_unused_submodules()
|
| 400 |
+
gm.recompile()
|
| 401 |
+
return gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch._custom_ops
|
| 5 |
+
from torch._C import DispatchKey
|
| 6 |
+
from torch._higher_order_ops.strict_mode import strict_mode
|
| 7 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 8 |
+
from torch._ops import HigherOrderOperator
|
| 9 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 10 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 11 |
+
from torch.utils import _pytree as pytree
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
|
| 18 |
+
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
|
| 19 |
+
if not mode.enable_tracing:
|
| 20 |
+
return _export_tracepoint(*args, **kwargs)
|
| 21 |
+
p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
|
| 22 |
+
proxy = mode.tracer.create_proxy(
|
| 23 |
+
"call_function", _export_tracepoint, p_args, p_kwargs
|
| 24 |
+
)
|
| 25 |
+
return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@_export_tracepoint.py_impl(FakeTensorMode)
|
| 29 |
+
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
|
| 30 |
+
with mode:
|
| 31 |
+
return args
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@_export_tracepoint.py_functionalize_impl
|
| 35 |
+
def export_tracepoint_functional(ctx, *args, **kwargs):
|
| 36 |
+
unwrapped_args = ctx.unwrap_tensors(args)
|
| 37 |
+
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
| 38 |
+
|
| 39 |
+
with ctx.redispatch_to_next():
|
| 40 |
+
out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
|
| 41 |
+
return ctx.wrap_tensors(out)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
_export_tracepoint.py_impl(DispatchKey.Autograd)(
|
| 45 |
+
autograd_not_implemented(_export_tracepoint, deferred_error=True)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@_export_tracepoint.py_impl(DispatchKey.CPU)
|
| 50 |
+
def export_tracepoint_cpu(*args, **kwargs):
|
| 51 |
+
return args
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _wrap_submodule(mod, path, module_call_specs):
|
| 55 |
+
assert isinstance(mod, torch.nn.Module)
|
| 56 |
+
assert path != ""
|
| 57 |
+
submodule = mod
|
| 58 |
+
for name in path.split("."):
|
| 59 |
+
if not hasattr(submodule, name):
|
| 60 |
+
raise RuntimeError(f"Couldn't find submodule at path {path}")
|
| 61 |
+
submodule = getattr(submodule, name)
|
| 62 |
+
|
| 63 |
+
def update_module_call_signatures(path, in_spec, out_spec):
|
| 64 |
+
if path in module_call_specs:
|
| 65 |
+
assert module_call_specs[path]["in_spec"] == in_spec
|
| 66 |
+
assert module_call_specs[path]["out_spec"] == out_spec
|
| 67 |
+
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
|
| 68 |
+
|
| 69 |
+
def check_flattened(flat_args):
|
| 70 |
+
for a in flat_args:
|
| 71 |
+
if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
|
| 72 |
+
raise AssertionError(
|
| 73 |
+
f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def pre_hook(module, args, kwargs):
|
| 77 |
+
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
|
| 78 |
+
check_flattened(flat_args)
|
| 79 |
+
flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
|
| 80 |
+
args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
|
| 81 |
+
return args, kwargs
|
| 82 |
+
|
| 83 |
+
def post_hook(module, args, kwargs, res):
|
| 84 |
+
_, in_spec = pytree.tree_flatten((args, kwargs))
|
| 85 |
+
flat_res, out_spec = pytree.tree_flatten(res)
|
| 86 |
+
check_flattened(flat_res)
|
| 87 |
+
flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
|
| 88 |
+
update_module_call_signatures(path, in_spec, out_spec)
|
| 89 |
+
return pytree.tree_unflatten(flat_res, out_spec)
|
| 90 |
+
|
| 91 |
+
pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
|
| 92 |
+
post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
|
| 93 |
+
return pre_handle, post_handle
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@contextmanager
|
| 97 |
+
def _wrap_submodules(f, preserve_signature, module_call_signatures):
|
| 98 |
+
handles = []
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
for path in preserve_signature:
|
| 102 |
+
handles.extend(_wrap_submodule(f, path, module_call_signatures))
|
| 103 |
+
yield
|
| 104 |
+
finally:
|
| 105 |
+
for handle in handles:
|
| 106 |
+
handle.remove()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _mark_strict_experimental(cls):
|
| 110 |
+
def call(self, *args):
|
| 111 |
+
return strict_mode(self, args)
|
| 112 |
+
|
| 113 |
+
cls.__call__ = call
|
| 114 |
+
return cls
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (358 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc
ADDED
|
Binary file (54.8 kB). View file
|
|
|