diff --git a/.gitattributes b/.gitattributes index 1636cdfcaf1b223a53daf8e2a83182392142f306..0b69d72b4b0e045b50df7dfef2b0165c49aeafd9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -125,3 +125,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 b/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 new file mode 100644 index 0000000000000000000000000000000000000000..d12a8ac357cbd98bfe109ec0338cc1ca55207262 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:242b9dba953ae2e4878d66032624135a9118a1616ca24588ed586d4bcc475c69 +size 108421928 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91d893e05cb89c65a9bf2a07b5ed973d20f9a2c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/__init__.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import functools +import io +import json +import logging +import os +import re +import sys +import types +import warnings +import weakref +import zipfile +from collections import OrderedDict +from contextlib import contextmanager +from functools import lru_cache + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.fx +import torch.utils._pytree as pytree + +from torch._dispatch.python import enable_python_dispatcher +from torch._utils_internal import log_export_usage +from torch.export._tree_utils import reorder_kwargs +from torch.export.graph_signature import ( + ArgumentSpec, + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) +from torch.fx import traceback as fx_traceback +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import make_fx +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from .wrappers import _wrap_submodules + +log = logging.getLogger(__name__) + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + allow_rnn: bool = True + + +# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph +# is called multiple times. +@lru_cache +def capture_pre_autograd_graph_warning(): + from torch._inductor import config + + log.warning("+============================+") + log.warning("| !!! WARNING !!! |") + log.warning("+============================+") + log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.") + log.warning("Please switch to use torch.export.export_for_training instead.") + if config.is_fbcode(): + log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 + + +@compatibility(is_backward_compatible=False) +def capture_pre_autograd_graph( + f: torch.nn.Module, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, +) -> torch.nn.Module: + """ + A helper function that is intended to trace a module before any pre-autograd + decomposition is run. The produced module will be "non-functional" and + composed of aten operators. Later this API will be deleted in favor of more general + torch.export API. + + Args: + f: nn.Module to be traced + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + Returns: + An nn.Module containing the traced method. + + """ + from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps + from torch._utils_internal import capture_pre_autograd_graph_using_training_ir + from torch._export.non_strict_utils import make_constraints + from torch._subclasses.functional_tensor import FunctionalTensor + from torch.export._unlift import _create_stateful_graph_module + from torch.export.dynamic_shapes import _combine_args + + capture_pre_autograd_graph_warning() + + if sys.platform == "win32": + raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows") + + assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance." + + if kwargs is None: + kwargs = {} + + if capture_pre_autograd_graph_using_training_ir(): + @lru_cache + def print_export_warning(): + log.warning("Using torch.export.export_for_training(...,strict=True)") + print_export_warning() + module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() + else: + log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) + + # Do not decompose dropout for exported models, because in eval mode the dropout + # op disappears from the graph, which makes it difficult to switch to train mode. + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. + decomp_table = { + op: op.decompose + for op in FunctionalTensor.maybe_aliasing_or_mutating_ops + if op != torch.ops.aten.dropout.default + } + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps(): + m = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, + assume_static_by_default=True, + tracing_mode="symbolic", + decomposition_table=decomp_table, + pre_dispatch=True, + aten_graph=True, + _log_export_usage=False, + )( + *args, + **kwargs, + )[0] + + _, _, fake_mode = _extract_fake_inputs(m, args, kwargs) + + m.meta["inline_constraints"] = { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if re.match(r"^[if]\d+$", str(k)) + } + + if isinstance(f, torch.nn.Module): + from torch.export._trace import _restore_state_dict + _restore_state_dict(f, m) + + flat_args, _ = pytree.tree_flatten((args, kwargs or {})) + combined_args = _combine_args(f, args, kwargs) + range_constraints = make_constraints( + fake_mode, + m, + combined_args, + dynamic_shapes, + 0, + ) + + module = _create_stateful_graph_module( + m, + range_constraints=range_constraints, + ) + + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: + + def _my_train(self, mode: bool = True): + ... + + def _my_eval(self): + ... + + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + + # Remove Proxy because they cannot be deepcopied or pickled. + if hasattr(module, "_buffers"): + torch._export.utils.remove_proxy_from_state_dict( + module._buffers, in_place=True + ) + return module + + +def aot_compile( + f: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, + same_signature: bool = True, +) -> str: + """ + Note: this function is not stable yet + + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside, generates executable cpp code from the program, and returns + the path to the generated shared library + + Args: + f: the `nn.Module` or callable to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + options: A dictionary of options to control inductor + + disable_constraint_solver: Whether the dim constraint solver must be disabled. + + Returns: + Path to the generated shared library + """ + from torch.export._trace import _export_to_torch_ir + from torch._inductor.decomposition import select_decomp_table + from torch._inductor import config + + if config.is_predispatch: + gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module() + else: + # We want to export to Torch IR here to utilize the pre_grad passes in + # inductor, which run on Torch IR. + gm = _export_to_torch_ir( + f, + args, + kwargs, + dynamic_shapes, + disable_constraint_solver=disable_constraint_solver, + same_signature=same_signature, + # Disabling this flag, because instead we can rely on the mapping + # dynamo_flat_name_to_original_fqn which is coming from Dynamo. + restore_fqn=False, + ) + + with torch.no_grad(): + so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type] + + return so_path + +def aot_load(so_path: str, device: str) -> Callable: + """ + Loads a shared library generated by aot_compile and returns a callable + + Args: + so_path: Path to the shared library + + Returns: + A callable + """ + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/case.py b/.venv/lib/python3.11/site-packages/torch/_export/db/case.py new file mode 100644 index 0000000000000000000000000000000000000000..b228f6c2c33773b07e5cb7e82abc1adcad6422d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/case.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple +from types import ModuleType + +import torch + +_TAGS: Dict[str, Dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. + """ + + SUPPORTED = 1 + NOT_SUPPORTED_YET = 0 + + +ArgsType = Tuple[Any, ...] + + +def check_inputs_type(args, kwargs): + if not isinstance(args, tuple): + raise ValueError( + f"Expecting args type to be a tuple, got: {type(args)}" + ) + if not isinstance(kwargs, dict): + raise ValueError( + f"Expecting kwargs type to be a dict, got: {type(kwargs)}" + ) + for key in kwargs: + if not isinstance(key, str): + raise ValueError( + f"Expecting kwargs keys to be a string, got: {type(key)}" + ) + +def _validate_tag(tag: str): + parts = tag.split(".") + t = _TAGS + for part in parts: + assert set(part) <= set( + string.ascii_lowercase + "-" + ), f"Tag contains invalid characters: {part}" + if part in t: + t = t[part] + else: + raise ValueError(f"Tag {tag} is not found in registered tags.") + + +@dataclass(frozen=True) +class ExportCase: + example_args: ArgsType + description: str # A description of the use case. + model: torch.nn.Module + name: str + example_kwargs: Dict[str, Any] = field(default_factory=dict) + extra_args: Optional[ArgsType] = None # For testing graph generalization. + # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) + tags: Set[str] = field(default_factory=set) + support_level: SupportLevel = SupportLevel.SUPPORTED + dynamic_shapes: Optional[Dict[str, Any]] = None + + def __post_init__(self): + check_inputs_type(self.example_args, self.example_kwargs) + if self.extra_args is not None: + check_inputs_type(self.extra_args, {}) + + for tag in self.tags: + _validate_tag(tag) + + if not isinstance(self.description, str) or len(self.description) == 0: + raise ValueError(f'Invalid description: "{self.description}"') + + +_EXAMPLE_CASES: Dict[str, ExportCase] = {} +_MODULES: Set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {} + + +def register_db_case(case: ExportCase) -> None: + """ + Registers a user provided ExportCase into example bank. + """ + if case.name in _EXAMPLE_CASES: + if case.name not in _EXAMPLE_CONFLICT_CASES: + _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]] + _EXAMPLE_CONFLICT_CASES[case.name].append(case) + return + + _EXAMPLE_CASES[case.name] = case + + +def to_snake_case(name): + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +def _make_export_case(m, name, configs): + if not isinstance(m, torch.nn.Module): + raise TypeError("Export case class should be a torch.nn.Module.") + + if "description" not in configs: + # Fallback to docstring if description is missing. + assert ( + m.__doc__ is not None + ), f"Could not find description or docstring for export case: {m}" + configs = {**configs, "description": m.__doc__} + return ExportCase(**{**configs, "model": m, "name": name}) + + +def export_case(**kwargs): + """ + Decorator for registering a user provided case into example bank. + """ + + def wrapper(m): + configs = kwargs + module = inspect.getmodule(m) + if module in _MODULES: + raise RuntimeError("export_case should only be used once per example file.") + + assert module is not None + _MODULES.add(module) + module_name = module.__name__.split(".")[-1] + case = _make_export_case(m, module_name, configs) + register_db_case(case) + return case + + return wrapper + + +def export_rewrite_case(**kwargs): + def wrapper(m): + configs = kwargs + + parent = configs.pop("parent") + assert isinstance(parent, ExportCase) + key = parent.name + if key not in _EXAMPLE_REWRITE_CASES: + _EXAMPLE_REWRITE_CASES[key] = [] + + configs["example_args"] = parent.example_args + case = _make_export_case(m, to_snake_case(m.__name__), configs) + _EXAMPLE_REWRITE_CASES[key].append(case) + return case + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..834dbce32f10bfb339fd2182a2455b43914441c9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import dataclasses +import glob +import inspect +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, + export_case, + ExportCase, +) + + +def _collect_examples(): + case_names = glob.glob(join(dirname(__file__), "*.py")) + case_names = [ + basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py") + ] + + case_fields = {f.name for f in dataclasses.fields(ExportCase)} + for case_name in case_names: + case = __import__(case_name, globals(), locals(), [], 1) + variables = [name for name in dir(case) if name in case_fields] + export_case(**{v: getattr(case, v) for v in variables})(case.model) + +_collect_examples() + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8468f372bbff90fa5c835fc0b1038f57d2295f37 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dbca6f2310a30a6c1db24c690b45c4e215336dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2263dc7dda8d58403f24f68f765b92d0ef7155bd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2519e359ff4c762e0a5b8135404e2921c699b92e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5c57c2f9d9825fd403a0da3f2d62da96c585513 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd14165f97f84d8bc9cb04316e7dbdad062fe9bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40da55f47aa02f4d92606776a2f3fe2f14e9a53d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42661a0c69b90ede9760fa6f8a1f17e94fd3361c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..928a3971d47a415a44deb5e359c01a75f47481c9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3ba21a43b0eb54ba7fededf02aa6aacc5609bd4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be29be826ef8c95517b3e5f71b0e3804b44a31a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5dacff2cc821403fa57a356c574bd92e1037f53 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f46eff133efe7c47bd017c863b3a6f7dc7033a45 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f94d041952bb7d391ea5d737c3688a93c804c05 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ed576538bfba1b3148f685c3e4fa9ea43ca6374 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554d51018d2cb74bd655708e2cd8e567f55db8df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e1a38e708e49494c7acf26b7597cd89aa8704ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a56994ad273c5eb232a35d11e2ff23a76385a35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a8ab16cd7c8765e3320f183d6bda979979ca77 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c93092757fd6500b1fe393396476e84e6d91ba5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..189ca3febf172dda5a68670bac0b3fc4309fef80 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f09bf55190cb877d9de4c84e6b4cebe0ed5f04 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..311a83f5fafc7868320b6588d461ecc74faf0d7c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71c2e6c654f5202782effbfc64cb8b87c0075854 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e71115880f7f056984c762c3953ca9d0cf58947 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0ebfec8af5ce8aab0a5057fbd9f991879e24fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..931ce7f7a50fc5a175101ac57c424c88cf31b54c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +import torch._dynamo as torchdynamo + + +class AssumeConstantResult(torch.nn.Module): + """ + Applying `assume_constant_result` decorator to burn make non-tracable code as constant. + """ + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"torch.escape-hatch"} +model = AssumeConstantResult() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..0458291e2176932b32b0fc2e44d51d5812dcead3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + 1 + +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend to + use `allow_in_graph` to mitigate this problem. + """ + + def forward(self, x): + return MyAutogradFunction.apply(x) + +example_args = (torch.randn(3, 2),) +model = AutogradFunction() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..f701f54d4f4ea1cb5816292cd60bb4df3d03c5e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. + """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) + +example_args = (torch.randn(3, 4),) +model = ClassMethod() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..22600cc504348d1d261b0ea2b9ed2d57af76b0a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchClassMethod() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b28ceeddc7956d136a8cf786c283344731d3e7ac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNestedFunction() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..50d0ec87a690d063cb0e841fc057a6ae95c369fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. + ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) + +example_args = (torch.randn(6),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNonlocalVariables() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..183180ab4fc825385170fea2bec6af184374a67e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. + """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) + +example_args = (torch.tensor(True), torch.randn(3, 2)) +tags = {"torch.cond", "python.closure"} +model = CondClosedOverVariable() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..29941d828ae63e45ad81b0f2c7428706df849cf0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,36 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim +from functorch.experimental.control_flow import cond + +x = torch.randn(3, 2) +y = torch.randn(2) +dim0_x = Dim("dim0_x") + +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) + +example_args = (x, y) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +extra_inputs = (torch.randn(2, 2), torch.randn(2)) +dynamic_shapes = {"x": {0: dim0_x}, "y": None} +model = CondOperands() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..68bb8850bba909a0c6546c8f12a1a3fa1bdc70d1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) + +example_args = (torch.randn(6, 4, 3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondPredicate() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..fbed4984ac4d14c0b175874c44e9c7009bfe8ddc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check and torch._check_is_size APIs. + torch._check_is_size is used for values that NEED to be used for constructing + tensor. + """ + + def forward(self, x): + a = x.item() + torch._check_is_size(a) + torch._check(a <= 5) + return torch.zeros((a, 5)) + + +example_args = (torch.tensor(4),) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsSizeExample() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..c8bfc3d6e36578ee50f1ca2eee80fed9aae14cfd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,28 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check and torch._check_is_size APIs. + torch._check is used for values that don't need to be used for constructing + tensor. + """ + + def forward(self, x, y): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + + if a < 6: + return y.sin() + return y.cos() + + +example_args = (torch.tensor(4), torch.randn(5, 5)) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsValueExample() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..7d24cc681a6b62adf40bfd9a2e5283afb3515131 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import functools + +import torch + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + +class Decorator(torch.nn.Module): + """ + Decorators calls are inlined into the exported function during tracing. + """ + + @test_decorator + def forward(self, x, y): + return x + y + +example_args = (torch.randn(3, 2), torch.randn(3, 2)) +model = Decorator() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..49e688bc0ac1f09567e3b877aaca29a1d02b4121 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened along tracing. + """ + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"python.data-structure"} +model = Dictionary() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..cc822e5553e1ab8bd350a26966c22f1a9a1698cf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x + +example_args = (torch.randn(3, 2),) +tags = {"python.assert"} +model = DynamicShapeAssert() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..157e460274ad58ba71c886b35364ddc0cd4d886a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. + """ + + def forward(self, x): + return torch.zeros(x.shape[0] * 2) + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeConstructor() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..21824ef3a0f66eb25f4d8e8c1ba92f53fdd4c275 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeIfGuard(torch.nn.Module): + """ + `if` statement with backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. However, export will fail if the + the dimension is marked as dynamic shape from higher level API. + """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() + +example_args = (torch.randn(3, 2, 2),) +tags = {"torch.dynamic-shape", "python.control-flow"} +model = DynamicShapeIfGuard() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..f8066aed556b9ee588b9744d17ba16c35d8fed6c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import map + +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"torch.dynamic-shape", "torch.map"} +model = DynamicShapeMap() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..decbf036553cb76544a19e531e5aee98d792ae0b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import torch + +from torch._export.db.case import SupportLevel +from torch.export import Dim + +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def forward(self, x): + return x[: round(x.shape[0] / 2)] + +x = torch.randn(3, 2) +dim0_x = Dim("dim0_x") +example_args = (x,) +tags = {"torch.dynamic-shape", "python.builtin"} +support_level = SupportLevel.NOT_SUPPORTED_YET +dynamic_shapes = {"x": {0: dim0_x}} +model = DynamicShapeRound() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..360fe15f6f98d9d735366bfa53371d79e0b00209 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. + """ + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeSlicing() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py new file mode 100644 index 0000000000000000000000000000000000000000..c45d4aeebb0282a0f56c58a587b4bfe1655f50e3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeView(torch.nn.Module): + """ + Dynamic shapes should be propagated to view arguments instead of being + baked into the exported graph. + """ + + def forward(self, x): + new_x_shape = x.size()[:-1] + (2, 5) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1) + +example_args = (torch.randn(10, 10),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeView() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..46b2637b398c21bf9399d0a3fa2a964354beea3e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out + +example_args = ( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)] +) +example_kwargs = { + "mykw0": torch.randn(4), + "input0": torch.randn(4), + "input1": torch.randn(4), +} +tags = {"python.data-structure"} +model = FnWithKwargs() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..35a140f4ee2e5d6f42c3509984333db896f1c081 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"} +model = ListContains() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2f8e2469a04b6b9d3358ed3602e9c1a8d6f35b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +from typing import List + +import torch + +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + + def forward(self, args: List[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + x, *y = args + return x + y[0] + +example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],) +tags = {"python.control-flow", "python.data-structure"} +model = ListUnpack() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa623c7dc39efd94fecb8eb32caac3f7420f05d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation is not supported. + """ + + def __init__(self) -> None: + super().__init__() + self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() + + +example_args = (torch.randn(3, 2),) +tags = {"python.object-model"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = ModelAttrMutation() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..e4076ac14dada40b4d78812666a9ec6b5e67045b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"python.closure"} +model = NestedFunction() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..80d09f68097edbe676077be183711dabe5cbc664 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def forward(self, x): + """ + Null context manager in Python will be traced out. + """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() + +example_args = (torch.randn(3, 2),) +tags = {"python.context-manager"} +model = NullContextManager() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..9693aa476f0eb31ec912c840ed3a4426ac832de8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.randn(2, 3)): + if y is not None: + return x + y + return x + + +example_args = (torch.randn(2, 3),) +tags = {"python.object-model"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = OptionalInput() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..fe401b75e8b94be80247a806839be9c1c89bce25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +from torch.utils import _pytree as pytree + +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. + """ + + def forward(self, x): + y, spec = pytree.tree_flatten(x) + return y[0] + 1 + +example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), +model = PytreeFlatten() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py new file mode 100644 index 0000000000000000000000000000000000000000..86d3b4645330c47c3625736b695d635f4ab58c70 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +dim1_x = Dim("dim1_x") + +class ScalarOutput(torch.nn.Module): + """ + Returning scalar values from the graph is supported, in addition to Tensor + outputs. Symbolic shapes are captured and rank is specialized. + """ + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x.shape[1] + 1 + +example_args = (x,) +tags = {"torch.dynamic-shape"} +dynamic_shapes = {"x": {1: dim1_x}} +model = ScalarOutput() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..f17092f9afc681b91a982a8a2479ac1dde4f455d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +from enum import Enum + +import torch + +class Animal(Enum): + COW = "moo" + +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") + +example_args = (torch.randn(3, 2),) +model = SpecializedAttribute() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..3924643bd94c299168f8f69155e846ed20dae1ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def forward(self, x): + ret = [] + for i in range(10): # constant + ret.append(i + x) + return ret + +example_args = (torch.randn(3, 2),) +tags = {"python.control-flow"} +model = StaticForLoop() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..f169380159a45489142ce5ae3523b2e4504c6721 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. + """ + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x + +example_args = (torch.randn(3, 2, 2),) +tags = {"python.control-flow"} +model = StaticIf() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbc263e7ff2240a3cf8618c56f152e744aa40e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + + +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 + +example_args = (torch.randn(3, 2), "attr") +tags = {"python.builtin"} +model = TensorSetattr() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..99ad42a153c512d65aaae1bcac2377ee1e124f25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class A: + @classmethod + def func(cls, x): + return 1 + x + +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. + """ + + def forward(self, x): + a = A() + return type(a).func(x) + + +example_args = (torch.randn(3, 4),) +tags = {"python.builtin"} +model = TypeReflectionMethod() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/unsupported_operator.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/unsupported_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a52d80b895b3b2c2d85b878ca4efea511e73ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/unsupported_operator.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. + """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) + + +example_args = (torch.randn(3, 2),) +tags = {"torch.operator"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = TorchSymMin() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..3156b3a1bf2ec6f6361395de3dacb098ecf20c3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + + +class UserInputMutation(torch.nn.Module): + """ + Directly mutate user input in forward + """ + + def forward(self, x): + x.mul_(2) + return x.cos() + + +example_args = (torch.randn(3, 2),) +tags = {"torch.mutation"} +model = UserInputMutation() diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py b/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..8e44cade322bdde858c5dd05ac116cef47202a33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py @@ -0,0 +1,21 @@ +import os +import sys + +import torch._export.db.examples as examples + +TEMPLATE = '''import torch + +def {case_name}(x): + """ + """ + + return +''' + +if __name__ == "__main__": + assert len(sys.argv) == 2 + root_dir = examples.__name__.replace(".", "/") + assert os.path.exists(root_dir) + with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f: + print("Writing to", f.name, "...") + f.write(TEMPLATE.format(case_name=sys.argv[1])) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py b/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..2078113fef157e38a465c78156c3b22ff4c235c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs + + +def exportdb_error_message(case_name: str): + from .examples import all_examples + from torch._utils_internal import log_export_usage + + ALL_EXAMPLES = all_examples() + # Detect whether case_name is really registered in exportdb. + if case_name in ALL_EXAMPLES: + url_case_name = case_name.replace("_", "-") + return f"See {case_name} in exportdb for unsupported case. \ + https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}" + else: + log_export_usage( + event="export.error.casenotregistered", + message=case_name, + ) + return f"{case_name} is unsupported." + + +def get_class_if_classified_error(e): + """ + Returns a string case name if the export error e is classified. + Returns None otherwise. + """ + + from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError + + ALWAYS_CLASSIFIED = "always_classified" + DEFAULT_CLASS_SIGIL = "case_name" + + # add error types that should be classified, along with any attribute name + # whose presence acts like a sigil to further distinguish which errors of + # that type should be classified. If the attribute name is None, then the + # error type is always classified. + _ALLOW_LIST = { + Unsupported: DEFAULT_CLASS_SIGIL, + UserError: DEFAULT_CLASS_SIGIL, + TorchRuntimeError: None, + } + if type(e) in _ALLOW_LIST: + attr_name = _ALLOW_LIST[type(e)] + if attr_name is None: + return ALWAYS_CLASSIFIED + return getattr(e, attr_name, None) + return None diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e6f169cfc611e5e89f3fd219e37e33c039edba Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c7da42aeb7ab478e3dfdc46cfd03c96469936b3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..439b5eb622d6a72f4d6bed91adc4b78c0923d89d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa9b8093c370dd565dfb7fb44e4b22474446af0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Set + + +NodeMetadataValue = Any + + +PROTECTED_KEYS: Set[str] = { + "val", + "stack_trace", + "nn_module_stack", + "debug_handle", + "tensor_meta", +} + + +class NodeMetadata: + def __init__(self, data: Dict[str, Any]) -> None: + self.data: Dict[str, Any] = data.copy() + + def __getitem__(self, key: str) -> NodeMetadataValue: + return self.data[key] + + def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: + if key in PROTECTED_KEYS: + raise RuntimeError(f"Could not override node key: {key}") + self.data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.data + + def copy(self) -> "NodeMetadata": + return NodeMetadata(self.data.copy()) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py new file mode 100644 index 0000000000000000000000000000000000000000..07d888b306560e8e9a29a4d49748bd28c32720eb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py @@ -0,0 +1,42 @@ +# mypy: allow-untyped-defs +# pyre-strict +from typing import Union + +import torch + + +class ProxyValue: + # pyre-ignore + def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]): + # pyre-ignore + self.data = data + self.proxy_or_node = proxy + + @property + def node(self) -> torch.fx.Node: + if isinstance(self.proxy_or_node, torch.fx.Node): + return self.proxy_or_node + assert isinstance(self.proxy_or_node, torch.fx.Proxy) + return self.proxy_or_node.node + + @property + def proxy(self) -> torch.fx.Proxy: + if not isinstance(self.proxy_or_node, torch.fx.Proxy): + raise RuntimeError( + f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self): + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/utils.py b/.venv/lib/python3.11/site-packages/torch/_export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e085e18b68a20753efa38d6f85e6eadf78f8895a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/utils.py @@ -0,0 +1,893 @@ +# mypy: allow-untyped-defs +import ast +import dataclasses +import inspect +import math +import operator +import re +from inspect import Parameter +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING + +import torch +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import FakeTensor + + +if TYPE_CHECKING: + from torch._export.passes.lift_constants_pass import ConstantAttrMap + from torch.export import ExportedProgram + from torch.export.graph_signature import ExportGraphSignature + +from torch.export.graph_signature import InputKind, OutputKind +from torch.utils._pytree import ( + _register_pytree_node, + Context, + FlattenFunc, + FromDumpableContextFn, + GetAttrKey, + KeyPath, + keystr, + MappingKey, + SequenceKey, + ToDumpableContextFn, + tree_flatten_with_path, + UnflattenFunc, +) + + +placeholder_prefixes = { + InputKind.USER_INPUT: "", + InputKind.PARAMETER: "p_", + InputKind.BUFFER: "b_", + InputKind.CONSTANT_TENSOR: "c_", + InputKind.CUSTOM_OBJ: "obj_", + InputKind.TOKEN: "token", +} + + +def _collect_and_set_constant_attrs( + graph_signature, constants, mod +) -> "ConstantAttrMap": + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. This is intended to only be used + # in run_decompositions where we still have access to original EP. + from torch._export.passes.lift_constants_pass import ConstantAttrMap + + constant_attrs = ConstantAttrMap() + non_persistent_buffers = { + spec.target + for spec in graph_signature.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + for name, value in constants.items(): + if name in non_persistent_buffers: + continue + # recursive getattr + _mod = mod + *atoms, attr = name.split(".") + for atom in atoms: + _mod = getattr(_mod, atom) + # remove as buffer, reassign as constant/non-persistent buffer + _mod._buffers.pop(attr, None) + setattr(_mod, attr, value) + constant_attrs.add(value, name) + return constant_attrs + + +def _overwrite_signature_for_non_persistent_buffers( + old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" +): + # overwrite signature for non-persistent buffers + non_persistent_buffers = { + spec.target + for spec in old_sig.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + + for spec in new_sig.input_specs: + if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: + spec.persistent = False + return new_sig + + +def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: + """ + Param/buffer metadata needs to be saved before lowering to aten IR + because aten IR lifts them, as a result, automatic preservation doesn't work. + This is intended to be called on the strict mode tracing right before lowering to + aten IR OR run_decomposition pass. + """ + params_buffers_to_node_meta = {} + + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + for node in mod.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = _getattr(mod, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = _getattr(mod, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. + # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + return params_buffers_to_node_meta + + +def _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta: Dict[str, Any], + gm: torch.fx.GraphModule, + new_sig: "ExportGraphSignature", +) -> None: + """ + Given that we collected param'buffer metadata before, we put them back in + newly traced graph module + """ + # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes + for metadata in params_buffers_to_node_meta.values(): + metadata.pop("nn_module_stack", None) + metadata.pop("stack_trace", None) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in new_sig.inputs_to_parameters: + param_name = new_sig.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in new_sig.inputs_to_buffers: + buffer_name = new_sig.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + +def _get_shape_env_from_gm(gm: torch.fx.GraphModule): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + + fake_mode = _detect_fake_mode_from_gm(gm) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _rename_without_collisions( + name_map: Dict[str, str], + orig_name: str, + name: str, + is_placeholder: bool = False, +): + """ + Renames nodes to avoid name collisions, with suffixing. + name_map: map from original name to new name + orig_name: mapping key + name: candidate name (potentially suffixed, e.g. mul_2) + is_placeholder: if the node is a placeholder, avoid detecting suffix + """ + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name + return name_map[orig_name] + + +def _check_input_constraints_for_graph( + input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints +): + def get_keystr(key_path: KeyPath) -> str: + """For a given index into the flat_args, return a human readable string + describing how to access it, e.g. "*args["foo"][0].bar" + """ + # Prefix the keypath with "*args" or "**kwargs" to make it clearer where + # the arguments come from. Ultimately we ought to serialize the + # original arg names for the best error message here. + args_kwargs_key_path = key_path[0] + assert isinstance(args_kwargs_key_path, SequenceKey) + if args_kwargs_key_path.idx == 0: + return f"*args{keystr(key_path[1:])}" + else: + kwarg_key = key_path[1] + assert isinstance(kwarg_key, MappingKey) + name = str(kwarg_key)[1:-1] # get rid of the enclosed [] + return f"{name}{keystr(key_path[2:])}" + + import sympy + + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, + ) + from torch.utils._sympy.solve import try_solve + + if len(flat_args_with_path) != len(input_placeholders): + raise RuntimeError( + "Unexpected number of inputs " + f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" + ) + # NOTE: export already guarantees that the same symbol is used in metadata + # for all InputDims related by equality constraints, so we can just unify + # symbols with given input dimension values to check equality constraints. + unification_map: Dict[sympy.Symbol, Any] = {} + for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): + node_val = node.meta.get("val") + if isinstance(node_val, FakeTensor): + if not isinstance(arg, torch.Tensor): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", + ) + + if len(node_val.shape) != len(arg.shape): + raise RuntimeError( + f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " + f"(expected {node_val.shape}, got {arg.shape})" + ) + + for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): + # TODO(avik): Assert the following property in the IR verifier: + # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr + if ( + isinstance(node_dim, torch.SymInt) + and len(node_dim.node.expr.free_symbols) == 1 + ): + symbol = next(iter(node_dim.node.expr.free_symbols)) + if symbol in unification_map: + existing_dim = node_dim.node.expr.subs(unification_map) + if arg_dim != existing_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{existing_dim}, but got {arg_dim}", + ) + else: + if ( + isinstance(arg_dim, torch.SymInt) + and not arg_dim.node.expr.is_number + ): + # This can happen when, say, arg is a fake tensor. + # We do not run checks on symbolic shapes of fake inputs as + # such checks can affect the shape env. + pass + else: + if isinstance(node_dim.node.expr, sympy.Symbol): + # Short cut for try_solve below. Also useful in cases where + # sympy.Eq(node_dim.node.expr, arg_dim) would evaluate to False + # purely because symbol is constrained to be size-like, + # e.g., when node_dim.node.expr = symbol and arg_dim = 0. + unification_map[symbol] = int(arg_dim) + else: + solution = try_solve( + sympy.Eq(node_dim.node.expr, arg_dim), symbol + ) + if solution is None: + raise RuntimeError( # noqa: B904 + f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " + f"of the form {node_dim.node.expr}, where {symbol} is an integer" + ) + else: + unification_map[symbol] = int(solution[1]) + + if node_dim.node.expr in range_constraints: + min_val, max_val = _convert_range_to_int( + range_constraints[node_dim.node.expr] + ) + # NOTE: we allow dimensions to be 0/1 at runtime + if min_val > 2: + if arg_dim < min_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " + f"{min_val}, but got {arg_dim}", + ) + if max_val < math.inf: + if arg_dim > max_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " + f"{max_val}, but got {arg_dim}", + ) + else: + if arg_dim != node_dim: + if ( + isinstance(node_dim, torch.SymInt) + and not node_dim.node.expr.is_number + ): + # this means we deferred a guard from export analysis to runtime, let this pass + # we'll add a runtime assert checking equality to this replacement expression + continue + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{node_dim}, but got {arg_dim}", + ) + elif isinstance(node_val, (int, float, str)): + if type(arg) != type(node_val) or arg != node_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + + +def register_dataclass_as_pytree_node( + cls: Type[Any], + flatten_fn: Optional[FlattenFunc] = None, + unflatten_fn: Optional[UnflattenFunc] = None, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + return_none_fields: bool = False, +) -> None: + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" + + def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for f in dataclasses.fields(obj): + name, val = f.name, getattr(obj, f.name) + if val is not None or return_none_fields: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: + flattened, (flat_names, none_names) = flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn + unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=default_flatten_fn_with_keys, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a parameter within the exported program + """ + + return node.name in program.graph_signature.inputs_to_parameters + + +def get_param( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.nn.Parameter]: + """ + Returns the parameter associated with the given node in the exported program. + Returns None if the node is not a parameter within the exported program + """ + + if is_param(program, node): + parameter_name = program.graph_signature.inputs_to_parameters[node.name] + return program.state_dict[parameter_name] + + return None + + +def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a buffer within the exported program + """ + + return node.name in program.graph_signature.inputs_to_buffers + + +def get_buffer( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the buffer associated with the given node in the exported program. + Returns None if the node is not a buffer within the exported program + """ + + if is_buffer(program, node): + buffer_name = program.graph_signature.inputs_to_buffers[node.name] + if buffer_name in program.graph_signature.non_persistent_buffers: + return program.constants[buffer_name] + else: + return program.state_dict[buffer_name] + + return None + + +def is_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> bool: + """ + Checks if the given node is a lifted tensor constant within the exported program + """ + + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def get_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the lifted tensor constant associated with the given node in the exported program. + Returns None if the node is not a lifted tensor constant within the exported program + """ + + if is_lifted_tensor_constant(program, node): + lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ + node.name + ] + return program.constants[lifted_tensor_name] + + return None + + +def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: + """ + sequential_split creates a new graph module that splits the input graph module into multiple submodules + based on the node_call_back. It doesn't mutate the input graph module. The node_call_back should return + True if the node is a delimiter. Delimiter will be the first node in the next submodule. + """ + from torch.fx.passes.split_module import split_module + + split_map = {} + split_id = 0 + for node in gm.graph.nodes: + if node_call_back(node): + split_id += 1 + split_map[node] = split_id + + new_gm = split_module( + gm, + gm, + lambda node: split_map[node], + keep_original_order=True, + keep_original_node_name=True, + ) + # Keep the codegen from original graph module to preserve e.g. pytree info. + new_gm.graph._codegen = gm.graph._codegen + new_gm.recompile() + return new_gm + + +def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """Returns the nodes that match the node_call_back as a list.""" + return [node for node in nodes if node_call_back(node)] + + +def nodes_first( + nodes: List[torch.fx.Node], node_call_back=None +) -> Optional[torch.fx.Node]: + """ + Returns the first node that matches the node_call_back. If no node matches, returns None. + When node_call_back is None, returns the first node in the node list. + """ + ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) + if len(ret) > 0: + return ret[0] + return None + + +def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: + """Returns the number of nodes that match the node_call_back.""" + return len(nodes_filter(nodes, node_call_back)) + + +def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """ + Sequentially visit the nodes list and invoke node_call_back on each element. + Returns the nodes list after the node_call_back is invoked on each element. + """ + for node in nodes: + node_call_back(node) + return nodes + + +def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: + """ + Replace all uses of old_node with new_node. + """ + old_node.replace_all_uses_with(new_node) + old_node.users.clear() + old_node.graph.erase_node(old_node) + + +def node_inline_(call_mod_node: torch.fx.Node) -> None: + """ + Inline the submodule of the given node into the parent module. + Note: we only support the case where submodule takes tensors inputs. + """ + assert call_mod_node.op == "call_module" + gm = call_mod_node.graph.owning_module + + assert isinstance(call_mod_node.target, str) + sub_gm = getattr(gm, call_mod_node.target) + + phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") + body = ( + node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") + ) + output = [node for node in sub_gm.graph.nodes if node.op == "output"] + + for ph, arg in zip(phs, call_mod_node.args): + assert isinstance(arg, torch.fx.Node) + node_replace_(ph, arg) + + with gm.graph.inserting_before(call_mod_node): + for node in body: + new_node = gm.graph.node_copy(node) + node_replace_(node, new_node) + + if len(output) > 0: + assert len(output) == 1 and len(output[0].args) == 1 + new_output = output[0].args[0] + + if isinstance(new_output, torch.fx.Node): + # Clear the users of the output node and set + # the users to be the users of original call_module node. + new_output.users.clear() + node_replace_(call_mod_node, new_output) + elif isinstance(new_output, (list, tuple)): + # Pop subgraph output node from users. + for node in new_output: + node.users.pop(output[0]) + + # Inline the get_item calls for the output node. + get_item_users = nodes_filter( + list(call_mod_node.users.keys()), + lambda node: node.op == "call_function" + and node.target == operator.getitem, + ) + # get_item_node.args[1] is the idx referring to new_output[idx] + nodes_map( + get_item_users, + lambda get_item_node: node_replace_( + get_item_node, + new_output[get_item_node.args[1]], + ), + ) + call_mod_node.graph.erase_node(call_mod_node) + else: + raise NotImplementedError( + f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." + ) + else: + call_mod_node.graph.erase_node(call_mod_node) + + gm.delete_all_unused_submodules() + gm.recompile() + return gm + + +def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module): + """ + Get source code and parse argument names using AST. The function returns + a signature of the forward() function. + + # TODO: Directly provide inspect.signature compatible TS-d module. + """ + ast_mod = ast.parse(mod.code) + ast_func_def: ast.FunctionDef = ast_mod.body[0] # type: ignore[assignment] + + # FIXME(jiashenc): TorchScript should only allow positional or keywords arguments. + arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD} + + # Traverse all argument types in AST tree and create associated parameters. + param_list = [] + for arg_type, param_type in arg_type_map.items(): + arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)] + for arg_name in arg_name_list: + if arg_name == "self": + continue # Skip self argument. + param_list.append(inspect.Parameter(arg_name, param_type)) + + return inspect.Signature(parameters=param_list) + + +def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): + if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)): + sig = _get_torch_jit_trace_forward_signature(mod) + + # Sanity check for placeholder names coming from TorchScript. + assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), ( + "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() " + "are not supported in _get_torch_jit_trace_forward_signature" + ) + else: + sig = inspect.signature(mod.forward) + + return sig.bind(*fake_args, **fake_kwargs).arguments + + +def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + """ + Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, + and handle collisions with non-placeholders by count suffixing. + Different HOO subgraph types have different input schemas, so we first enumerate them + and gather the top-level named placeholder nodes. + """ + # gather all HOO subgraphs and their top-level named placeholder nodes + subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + # HOO subgraphs have varying input schemas, so we enumerate them there + if node.target._name == "cond": + _, true_graph, false_graph, cond_args = node._args + subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) + subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) + elif node.target._name == "wrap_with_set_grad_enabled": + subgraph, phs = node._args[1], node._args[2:] + subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) + elif node.target._name == "map_impl": + body_graph, array, args = node._args + subgraph_ph_tuples.append( + (getattr(gm, body_graph.target), array + args) + ) + + # propagate names + for subgraph, hoo_phs in subgraph_ph_tuples: + name_map: Dict[str, str] = {} + for i, node in enumerate(subgraph.graph.nodes): + if i < len(hoo_phs): # placeholder, retain name + name_map[node.name] = hoo_phs[i].name + node.name = node.target = hoo_phs[i].name + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # recurse and recompile + _name_hoo_subgraph_placeholders(subgraph) + subgraph.recompile() + + +def placeholder_naming_pass( + gm: torch.fx.GraphModule, + export_graph_signature: "ExportGraphSignature", + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constants: Dict[str, Any], +) -> None: + """ + This pass is run at the end of _export_non_strict() to assign better placeholder node names: + - User inputs: + These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y. + For nested inputs from dictionaries, lists, tuples, or dataclasses, + the names are a concatenation of the path to the tensor. + e.g. x = { + 'a': torch.randn(), + 'b': [torch.randn(), torch.randn()] + } + produces nodes x_a, x_b_0, x_b_1. + - Parameters/buffers/constants/custom objects: + These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively. + e.g. self.bar.l0.weight produces "p_bar_l0_weight". + - Effect tokens: + These are named token, token_1, ... + """ + + def _strip_name(x): + if x.startswith("L__self___"): + x = x[len("L__self___") :] + elif x.startswith("self_"): + x = x[len("self_") :] + x = re.sub(r"[^a-zA-Z0-9]", "_", x) + return x + + def _extract_pytree_key(x): + if isinstance(x, MappingKey): + x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key)) + return x + elif isinstance(x, SequenceKey): + return str(x.idx) + elif isinstance(x, GetAttrKey): + return x.name + else: + raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") + + name_map: Dict[str, str] = {} + + # map user input names with mod.forward() signature + combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) + + flat_args_with_path, _ = tree_flatten_with_path(combined_args) + user_input_names = [ + spec.arg.name + for spec in export_graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + + # use pytree path to name nested user inputs + for (arg_path, arg), user_input_name in zip(flat_args_with_path, user_input_names): + if user_input_name: + _rename_without_collisions( + name_map, + user_input_name, + placeholder_prefixes[InputKind.USER_INPUT] + + "_".join(_extract_pytree_key(x).lower() for x in arg_path), + is_placeholder=True, + ) + + # use graph signature input specs to map param/buffer/constant names + # name effect tokens as token, token_1, ... (these aren't visible to user) + for spec in export_graph_signature.input_specs: + if spec.kind == InputKind.USER_INPUT: + continue + if spec.kind == InputKind.TOKEN: + base_name = "" + else: + base_name = _strip_name(spec.target).lower() + base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name) + + _rename_without_collisions( + name_map, + spec.arg.name, + placeholder_prefixes[spec.kind] + base_name, + is_placeholder=True, + ) + + # handle naming collisions with call_function/get_attr inputs. + # here, we want to prioritize user input names over call_function names + # e.g. not have forward(self, mul): lead to a placeholder node called mul_13, + # so we increment the suffix of call_function nodes as needed + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + _rename_without_collisions(name_map, node.name, node.name) + + # assign new node names + for node in gm.graph.nodes: + if node.op == "placeholder": + assert node.name in name_map + node.name = node.target = name_map[node.name] + elif node.name in name_map: + node.name = name_map[node.name] + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # re-generate graph module code + gm.recompile() + + # modify graph signature (input specs, output specs, user input mutations) + for spec in export_graph_signature.input_specs: + assert spec.arg.name in name_map + spec.arg.name = name_map[spec.arg.name] + if ( # handle targets for custom objects + spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map + ): + spec.target = name_map[spec.target][4:] # strip obj_ prefix + + for spec in export_graph_signature.output_specs: + if spec.arg.name in name_map: + spec.arg.name = name_map[spec.arg.name] + if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: + spec.target = name_map[spec.target] + + # rename keys in constants dict for custom objects + for name in list(constants.keys()): + constant = constants[name] + if name in name_map and not isinstance( + constant, torch.Tensor + ): # rename custom objects with generic names + new_name = name_map[name] + if ( + new_name != name + and re.match(r"arg(\d+)_1", name) + and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name + ): + constants[new_name] = constant + del constants[name] + + +def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict: + """ + If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`. + `v` is the values in the dictionary. + If `in_place` is true, modify `state_dict` in place. + """ + if in_place: + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + delattr(state_dict[k], "proxy") + return state_dict + else: + new_state_dict = {} + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + new_state_dict[k] = v.clone().detach() + else: + new_state_dict[k] = v + return new_state_dict + + +def _detect_fake_mode_from_gm( + gm: torch.fx.GraphModule, +) -> torch._subclasses.fake_tensor.FakeTensorMode: + """ + For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. + Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. + If no fake mode is found, we return None for fake_mode. + """ + + fake_inps: List[torch.Tensor] = [] + fake_vals: List[torch.Tensor] = [] + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_inps.append(fake_val) + elif len(fake_inps) == 0 and ( + "example_value" in node.meta or "val" in node.meta + ): + fake_val = None + if "example_value" in node.meta: + fake_val = node.meta["example_value"] + elif "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_vals.append(fake_val) + + return detect_fake_mode(fake_inps + fake_vals)