Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py +188 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py +46 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py +63 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py +19 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py +32 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py +27 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py +19 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py +17 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py +32 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (217 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc
ADDED
|
Binary file (1.51 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc
ADDED
|
Binary file (383 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import re
|
| 3 |
+
import string
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
| 7 |
+
from types import ModuleType
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
_TAGS: Dict[str, Dict[str, Any]] = {
|
| 12 |
+
"torch": {
|
| 13 |
+
"cond": {},
|
| 14 |
+
"dynamic-shape": {},
|
| 15 |
+
"escape-hatch": {},
|
| 16 |
+
"map": {},
|
| 17 |
+
"dynamic-value": {},
|
| 18 |
+
"operator": {},
|
| 19 |
+
"mutation": {},
|
| 20 |
+
},
|
| 21 |
+
"python": {
|
| 22 |
+
"assert": {},
|
| 23 |
+
"builtin": {},
|
| 24 |
+
"closure": {},
|
| 25 |
+
"context-manager": {},
|
| 26 |
+
"control-flow": {},
|
| 27 |
+
"data-structure": {},
|
| 28 |
+
"standard-library": {},
|
| 29 |
+
"object-model": {},
|
| 30 |
+
},
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SupportLevel(Enum):
|
| 35 |
+
"""
|
| 36 |
+
Indicates at what stage the feature
|
| 37 |
+
used in the example is handled in export.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
SUPPORTED = 1
|
| 41 |
+
NOT_SUPPORTED_YET = 0
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ExportArgs:
|
| 45 |
+
__slots__ = ("args", "kwargs")
|
| 46 |
+
|
| 47 |
+
def __init__(self, *args, **kwargs):
|
| 48 |
+
self.args = args
|
| 49 |
+
self.kwargs = kwargs
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
InputsType = Union[Tuple[Any, ...], ExportArgs]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def check_inputs_type(x):
|
| 56 |
+
if not isinstance(x, (ExportArgs, tuple)):
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"Expecting inputs type to be either a tuple, or ExportArgs, got: {type(x)}"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _validate_tag(tag: str):
|
| 63 |
+
parts = tag.split(".")
|
| 64 |
+
t = _TAGS
|
| 65 |
+
for part in parts:
|
| 66 |
+
assert set(part) <= set(
|
| 67 |
+
string.ascii_lowercase + "-"
|
| 68 |
+
), f"Tag contains invalid characters: {part}"
|
| 69 |
+
if part in t:
|
| 70 |
+
t = t[part]
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError(f"Tag {tag} is not found in registered tags.")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass(frozen=True)
|
| 76 |
+
class ExportCase:
|
| 77 |
+
example_inputs: InputsType
|
| 78 |
+
description: str # A description of the use case.
|
| 79 |
+
model: torch.nn.Module
|
| 80 |
+
name: str
|
| 81 |
+
extra_inputs: Optional[InputsType] = None # For testing graph generalization.
|
| 82 |
+
# Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
|
| 83 |
+
tags: Set[str] = field(default_factory=set)
|
| 84 |
+
support_level: SupportLevel = SupportLevel.SUPPORTED
|
| 85 |
+
dynamic_shapes: Optional[Dict[str, Any]] = None
|
| 86 |
+
|
| 87 |
+
def __post_init__(self):
|
| 88 |
+
check_inputs_type(self.example_inputs)
|
| 89 |
+
if self.extra_inputs is not None:
|
| 90 |
+
check_inputs_type(self.extra_inputs)
|
| 91 |
+
|
| 92 |
+
for tag in self.tags:
|
| 93 |
+
_validate_tag(tag)
|
| 94 |
+
|
| 95 |
+
if not isinstance(self.description, str) or len(self.description) == 0:
|
| 96 |
+
raise ValueError(f'Invalid description: "{self.description}"')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
|
| 100 |
+
_MODULES: Set[ModuleType] = set()
|
| 101 |
+
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
|
| 102 |
+
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def register_db_case(case: ExportCase) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Registers a user provided ExportCase into example bank.
|
| 108 |
+
"""
|
| 109 |
+
if case.name in _EXAMPLE_CASES:
|
| 110 |
+
if case.name not in _EXAMPLE_CONFLICT_CASES:
|
| 111 |
+
_EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
|
| 112 |
+
_EXAMPLE_CONFLICT_CASES[case.name].append(case)
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
_EXAMPLE_CASES[case.name] = case
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def to_snake_case(name):
|
| 119 |
+
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
| 120 |
+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _make_export_case(m, name, configs):
|
| 124 |
+
if not issubclass(m, torch.nn.Module):
|
| 125 |
+
raise TypeError("Export case class should be a torch.nn.Module.")
|
| 126 |
+
m = m()
|
| 127 |
+
|
| 128 |
+
if "description" not in configs:
|
| 129 |
+
# Fallback to docstring if description is missing.
|
| 130 |
+
assert (
|
| 131 |
+
m.__doc__ is not None
|
| 132 |
+
), f"Could not find description or docstring for export case: {m}"
|
| 133 |
+
configs = {**configs, "description": m.__doc__}
|
| 134 |
+
return ExportCase(**{**configs, "model": m, "name": name})
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def export_case(**kwargs):
|
| 138 |
+
"""
|
| 139 |
+
Decorator for registering a user provided case into example bank.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def wrapper(m):
|
| 143 |
+
configs = kwargs
|
| 144 |
+
module = inspect.getmodule(m)
|
| 145 |
+
if module in _MODULES:
|
| 146 |
+
raise RuntimeError("export_case should only be used once per example file.")
|
| 147 |
+
|
| 148 |
+
assert module is not None
|
| 149 |
+
_MODULES.add(module)
|
| 150 |
+
normalized_name = to_snake_case(m.__name__)
|
| 151 |
+
module_name = module.__name__.split(".")[-1]
|
| 152 |
+
if module_name != normalized_name:
|
| 153 |
+
raise RuntimeError(
|
| 154 |
+
f'Module name "{module.__name__}" is inconsistent with exported program '
|
| 155 |
+
+ f'name "{m.__name__}". Please rename the module to "{normalized_name}".'
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
case = _make_export_case(m, module_name, configs)
|
| 159 |
+
register_db_case(case)
|
| 160 |
+
return case
|
| 161 |
+
|
| 162 |
+
return wrapper
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def export_rewrite_case(**kwargs):
|
| 166 |
+
def wrapper(m):
|
| 167 |
+
configs = kwargs
|
| 168 |
+
|
| 169 |
+
parent = configs.pop("parent")
|
| 170 |
+
assert isinstance(parent, ExportCase)
|
| 171 |
+
key = parent.name
|
| 172 |
+
if key not in _EXAMPLE_REWRITE_CASES:
|
| 173 |
+
_EXAMPLE_REWRITE_CASES[key] = []
|
| 174 |
+
|
| 175 |
+
configs["example_inputs"] = parent.example_inputs
|
| 176 |
+
case = _make_export_case(m, to_snake_case(m.__name__), configs)
|
| 177 |
+
_EXAMPLE_REWRITE_CASES[key].append(case)
|
| 178 |
+
return case
|
| 179 |
+
|
| 180 |
+
return wrapper
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def normalize_inputs(x: InputsType) -> ExportArgs:
|
| 184 |
+
if isinstance(x, tuple):
|
| 185 |
+
return ExportArgs(*x)
|
| 186 |
+
|
| 187 |
+
assert isinstance(x, ExportArgs)
|
| 188 |
+
return x
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc
ADDED
|
Binary file (3.03 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc
ADDED
|
Binary file (3.23 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MyAutogradFunction(torch.autograd.Function):
|
| 7 |
+
@staticmethod
|
| 8 |
+
def forward(ctx, x):
|
| 9 |
+
return x.clone()
|
| 10 |
+
|
| 11 |
+
@staticmethod
|
| 12 |
+
def backward(ctx, grad_output):
|
| 13 |
+
return grad_output + 1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@export_case(
|
| 17 |
+
example_inputs=(torch.randn(3, 2),),
|
| 18 |
+
)
|
| 19 |
+
class AutogradFunction(torch.nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
TorchDynamo does not keep track of backward() on autograd functions. We recommend to
|
| 22 |
+
use `allow_in_graph` to mitigate this problem.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
return MyAutogradFunction.apply(x)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MySubModule(torch.nn.Module):
|
| 8 |
+
def foo(self, x):
|
| 9 |
+
return x.cos()
|
| 10 |
+
|
| 11 |
+
def forward(self, x):
|
| 12 |
+
return self.foo(x)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@export_case(
|
| 16 |
+
example_inputs=(torch.ones(3),),
|
| 17 |
+
tags={
|
| 18 |
+
"torch.cond",
|
| 19 |
+
"torch.dynamic-shape",
|
| 20 |
+
},
|
| 21 |
+
)
|
| 22 |
+
class CondBranchClassMethod(torch.nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 25 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 26 |
+
- both branches must return a single tensor
|
| 27 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 28 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 29 |
+
- branch function can not have closure variables
|
| 30 |
+
- no inplace mutations on inputs or global variables
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
This example demonstrates using class method in cond().
|
| 34 |
+
|
| 35 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.subm = MySubModule()
|
| 41 |
+
|
| 42 |
+
def bar(self, x):
|
| 43 |
+
return x.sin()
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import cond
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.ones(6),),
|
| 9 |
+
tags={
|
| 10 |
+
"torch.cond",
|
| 11 |
+
"torch.dynamic-shape",
|
| 12 |
+
},
|
| 13 |
+
)
|
| 14 |
+
class CondBranchNonlocalVariables(torch.nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
| 17 |
+
- both branches must take the same args, which must also match the branch args passed to cond.
|
| 18 |
+
- both branches must return a single tensor
|
| 19 |
+
- returned tensor must have the same tensor metadata, e.g. shape and dtype
|
| 20 |
+
- branch function can be free function, nested function, lambda, class methods
|
| 21 |
+
- branch function can not have closure variables
|
| 22 |
+
- no inplace mutations on inputs or global variables
|
| 23 |
+
|
| 24 |
+
This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
|
| 25 |
+
|
| 26 |
+
The code below will not work because capturing closure variables is not supported.
|
| 27 |
+
```
|
| 28 |
+
my_tensor_var = x + 100
|
| 29 |
+
my_primitive_var = 3.14
|
| 30 |
+
|
| 31 |
+
def true_fn(y):
|
| 32 |
+
nonlocal my_tensor_var, my_primitive_var
|
| 33 |
+
return y + my_tensor_var + my_primitive_var
|
| 34 |
+
|
| 35 |
+
def false_fn(y):
|
| 36 |
+
nonlocal my_tensor_var, my_primitive_var
|
| 37 |
+
return y - my_tensor_var - my_primitive_var
|
| 38 |
+
|
| 39 |
+
return cond(x.shape[0] > 5, true_fn, false_fn, [x])
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
my_tensor_var = x + 100
|
| 50 |
+
my_primitive_var = 3.14
|
| 51 |
+
|
| 52 |
+
def true_fn(x, y, z):
|
| 53 |
+
return x + y + z
|
| 54 |
+
|
| 55 |
+
def false_fn(x, y, z):
|
| 56 |
+
return x - y - z
|
| 57 |
+
|
| 58 |
+
return cond(
|
| 59 |
+
x.shape[0] > 5,
|
| 60 |
+
true_fn,
|
| 61 |
+
false_fn,
|
| 62 |
+
[x, my_tensor_var, torch.tensor(my_primitive_var)],
|
| 63 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch._export.db.case import export_case
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_decorator(func):
|
| 9 |
+
@functools.wraps(func)
|
| 10 |
+
def wrapper(*args, **kwargs):
|
| 11 |
+
return func(*args, **kwargs) + 1
|
| 12 |
+
|
| 13 |
+
return wrapper
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@export_case(
|
| 17 |
+
example_inputs=(torch.ones(3, 2), torch.ones(3, 2)),
|
| 18 |
+
)
|
| 19 |
+
class Decorator(torch.nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Decorators calls are inlined into the exported function during tracing.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
@test_decorator
|
| 25 |
+
def forward(self, x, y):
|
| 26 |
+
return x + y
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"python.assert"},
|
| 9 |
+
)
|
| 10 |
+
class DynamicShapeAssert(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
A basic usage of python assertion.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
# assertion with error message
|
| 19 |
+
assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2"
|
| 20 |
+
# assertion without error message
|
| 21 |
+
assert x.shape[0] > 1
|
| 22 |
+
return x
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"torch.dynamic-shape"},
|
| 9 |
+
)
|
| 10 |
+
class DynamicShapeConstructor(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Tensor constructors should be captured with dynamic shape inputs rather
|
| 13 |
+
than being baked in with static shape.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return torch.ones(x.shape[0] * 2)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
from functorch.experimental.control_flow import map
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@export_case(
|
| 8 |
+
example_inputs=(torch.ones(3, 2), torch.ones(2)),
|
| 9 |
+
tags={"torch.dynamic-shape", "torch.map"},
|
| 10 |
+
)
|
| 11 |
+
class DynamicShapeMap(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
functorch map() maps a function over the first tensor dimension.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
def forward(self, xs, y):
|
| 20 |
+
def body(x, y):
|
| 21 |
+
return x + y
|
| 22 |
+
|
| 23 |
+
return map(body, xs, y)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
from torch.export import Dim
|
| 5 |
+
|
| 6 |
+
x = torch.ones(3, 2)
|
| 7 |
+
dim0_x = Dim("dim0_x")
|
| 8 |
+
|
| 9 |
+
@export_case(
|
| 10 |
+
example_inputs=(x,),
|
| 11 |
+
tags={"torch.dynamic-shape", "python.builtin"},
|
| 12 |
+
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
| 13 |
+
dynamic_shapes={"x": {0: dim0_x}},
|
| 14 |
+
)
|
| 15 |
+
class DynamicShapeRound(torch.nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Calling round on dynamic shapes is not supported.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
return x[: round(x.shape[0] / 2)]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, ExportArgs, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=ExportArgs(
|
| 8 |
+
torch.randn(4),
|
| 9 |
+
(torch.randn(4), torch.randn(4)),
|
| 10 |
+
*[torch.randn(4), torch.randn(4)],
|
| 11 |
+
mykw0=torch.randn(4),
|
| 12 |
+
input0=torch.randn(4), input1=torch.randn(4)
|
| 13 |
+
),
|
| 14 |
+
tags={"python.data-structure"},
|
| 15 |
+
support_level=SupportLevel.SUPPORTED,
|
| 16 |
+
)
|
| 17 |
+
class FnWithKwargs(torch.nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Keyword arguments are not supported at the moment.
|
| 20 |
+
"""
|
| 21 |
+
def __init__(self):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
|
| 25 |
+
out = pos0
|
| 26 |
+
for arg in tuple0:
|
| 27 |
+
out = out * arg
|
| 28 |
+
for arg in myargs:
|
| 29 |
+
out = out * arg
|
| 30 |
+
out = out * mykw0
|
| 31 |
+
out = out * mykwargs["input0"] * mykwargs["input1"]
|
| 32 |
+
return out
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2), torch.ones(2)),
|
| 8 |
+
tags={"python.closure"},
|
| 9 |
+
)
|
| 10 |
+
class NestedFunction(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Nested functions are traced through. Side effects on global captures
|
| 13 |
+
are not supported though.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
def forward(self, a, b):
|
| 19 |
+
x = a + b
|
| 20 |
+
z = a - b
|
| 21 |
+
|
| 22 |
+
def closure(y):
|
| 23 |
+
nonlocal x
|
| 24 |
+
x += 1
|
| 25 |
+
return x * y + z
|
| 26 |
+
|
| 27 |
+
return closure(x)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch._export.db.case import export_case
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@export_case(
|
| 9 |
+
example_inputs=(torch.ones(3, 2),),
|
| 10 |
+
tags={"python.context-manager"},
|
| 11 |
+
)
|
| 12 |
+
class NullContextManager(torch.nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Null context manager in Python will be traced out.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
"""
|
| 22 |
+
Null context manager in Python will be traced out.
|
| 23 |
+
"""
|
| 24 |
+
ctx = contextlib.nullcontext()
|
| 25 |
+
with ctx:
|
| 26 |
+
return x.sin() + x.cos()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.randn(2, 3),),
|
| 8 |
+
tags={"python.object-model"},
|
| 9 |
+
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
| 10 |
+
)
|
| 11 |
+
class OptionalInput(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Tracing through optional input is not supported yet
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def forward(self, x, y=torch.ones(2, 3)):
|
| 17 |
+
if y is not None:
|
| 18 |
+
return x + y
|
| 19 |
+
return x
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"torch.operator"},
|
| 9 |
+
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
| 10 |
+
)
|
| 11 |
+
class TorchSymMin(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
torch.sym_min operator is not supported in export.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
return x.sum() + torch.sym_min(x.size(0), 100)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch._export.db.case import export_case, SupportLevel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@export_case(
|
| 7 |
+
example_inputs=(torch.ones(3, 2),),
|
| 8 |
+
tags={"torch.mutation"},
|
| 9 |
+
support_level=SupportLevel.SUPPORTED,
|
| 10 |
+
)
|
| 11 |
+
class UserInputMutation(torch.nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Directly mutate user input in forward
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
x.mul_(2)
|
| 18 |
+
return x.cos()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def exportdb_error_message(case_name: str):
|
| 2 |
+
return ""
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc
ADDED
|
Binary file (2.12 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Set
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
NodeMetadataValue = Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
PROTECTED_KEYS: Set[str] = {
|
| 8 |
+
"val",
|
| 9 |
+
"stack_trace",
|
| 10 |
+
"nn_module_stack",
|
| 11 |
+
"debug_handle",
|
| 12 |
+
"tensor_meta",
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NodeMetadata:
|
| 17 |
+
def __init__(self, data: Dict[str, Any]) -> None:
|
| 18 |
+
self.data: Dict[str, Any] = data.copy()
|
| 19 |
+
|
| 20 |
+
def __getitem__(self, key: str) -> NodeMetadataValue:
|
| 21 |
+
return self.data[key]
|
| 22 |
+
|
| 23 |
+
def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
|
| 24 |
+
if key in PROTECTED_KEYS:
|
| 25 |
+
raise RuntimeError(f"Could not override node key: {key}")
|
| 26 |
+
self.data[key] = value
|
| 27 |
+
|
| 28 |
+
def __contains__(self, key: str) -> bool:
|
| 29 |
+
return key in self.data
|
| 30 |
+
|
| 31 |
+
def copy(self) -> "NodeMetadata":
|
| 32 |
+
return NodeMetadata(self.data.copy())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc
ADDED
|
Binary file (4.39 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc
ADDED
|
Binary file (5.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc
ADDED
|
Binary file (1.45 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|