diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..758fcd7c447ebe84e3c169fc5e398acbe4908ca7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__init__.py @@ -0,0 +1,406 @@ +import copy +import dataclasses +import functools +import io +import json +import os +import re +import sys +import types +import warnings +import weakref +import zipfile +from collections import OrderedDict +from contextlib import contextmanager + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +import torch._dynamo +import torch.fx +import torch.utils._pytree as pytree + +from torch._decomp import core_aten_decompositions, get_decompositions +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.exc import UserError, UserErrorType +from torch._dynamo.source import ConstantSource +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._functorch.aot_autograd import aot_export_module, GraphSignature +from torch._functorch.eager_transforms import functionalize +from torch._guards import detect_fake_mode +from torch._inductor import config +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensor +from torch._utils_internal import log_export_usage +from torch.export._tree_utils import reorder_kwargs +from torch.export._unlift import _create_stateful_graph_module +from torch.export.dynamic_shapes import ( + _process_constraints, + _process_dynamic_shapes, + Constraint, + dims, + dynamic_dim, +) +from torch.export.exported_program import ( + _disable_prexisiting_fake_mode, + 
ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) +from torch.export.graph_signature import ( + _sig_to_specs, + ArgumentSpec, + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) +from torch.fx import traceback as fx_traceback +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + GuardOnDataDependentSymNode, + ShapeEnv, + StrictMinMaxConstraint, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges + +from .passes.add_runtime_assertions_for_constraints_pass import ( + _AddRuntimeAssertionsForInlineConstraintsPass, +) +from .wrappers import _wrap_submodules + + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + allow_rnn: bool = True + + +@compatibility(is_backward_compatible=False) +def capture_pre_autograd_graph( + f: torch.nn.Module, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, +) -> torch.nn.Module: + """ + A helper function that is intended to trace a module before any pre-autograd + decomposition is run. The produced module will be "non-functional" and + composed of aten operators. Later this API will be deleted in favor of more general + torch.export API. + + Args: + f: nn.Module to be traced + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. 
+ If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + Returns: + An nn.Module containing the traced method. + + """ + from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG + from torch.export.dynamic_shapes import _process_dynamic_shapes + + log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) + + assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance." + + if kwargs is None: + kwargs = {} + + constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes) + + # Do not decompose dropout for exported models, because in eval mode the dropout + # op disappears from the graph, which makes it difficult to switch to train mode. + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. 
+ decomp_table = { + op: op.decompose + for op in FunctionalTensor.maybe_aliasing_or_mutating_ops + if op != torch.ops.aten.dropout.default + } + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + m = torch._dynamo.export( + f, + constraints=constraints, + assume_static_by_default=True, + tracing_mode="symbolic", + decomposition_table=decomp_table, + pre_dispatch=True, + aten_graph=True, + _log_export_usage=False, + )( + *args, + **kwargs, + )[0] + + _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs) + + m.meta["inline_constraints"] = { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if re.match(r"^[if]\d+$", str(k)) + } + + if isinstance(f, torch.nn.Module): + from torch.export._trace import _restore_state_dict + _restore_state_dict(f, m) + + flat_args, _ = pytree.tree_flatten((args, kwargs or {})) + range_constraints = _process_constraints(fake_mode, m, 0, flat_args) + + module = _create_stateful_graph_module( + m, + range_constraints=range_constraints, + ) + + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: + + def _my_train(self, mode: bool = True): + ... + + def _my_eval(self): + ... 
+ + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + +def save( + ep: ExportedProgram, + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + opset_version: Optional[Dict[str, int]] = None, +) -> None: + if not isinstance(ep, ExportedProgram): + raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}") + + from .serde.serialize import serialize, SerializedArtifact + from .serde.schema import SCHEMA_VERSION + artifact: SerializedArtifact = serialize(ep, opset_version) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with zipfile.ZipFile(f, 'w') as zipf: + # Save every field the SerializedArtifact to a file + assert isinstance(artifact.exported_program, bytes) + zipf.writestr("serialized_exported_program.json", artifact.exported_program) + zipf.writestr("serialized_state_dict.pt", artifact.state_dict) + zipf.writestr("serialized_constants.pt", artifact.constants) + + zipf.writestr('version', ".".join(map(str, SCHEMA_VERSION))) + + # Add extra files if provided + if extra_files: + for extra_file_name, content in extra_files.items(): + encoded_content = content.encode('utf-8') + zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) + + +def load( + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ExportedProgram: + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + with zipfile.ZipFile(f, 'r') as zipf: + # Check the version + version = 
zipf.read('version').decode().split('.') + from .serde.schema import SCHEMA_VERSION + + assert len(version) == len(SCHEMA_VERSION) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + f"schema version {SCHEMA_VERSION}." + ) + + from .serde.serialize import deserialize, SerializedArtifact + + # Load serialized_ep and serialized_state_dict from the zip file + + serialized_exported_program: Optional[bytes] = None + serialized_state_dict: Optional[bytes] = None + serialized_constants: Optional[bytes] = None + + for file_info in zipf.infolist(): + file_content = zipf.read(file_info.filename) + + if file_info.filename == "serialized_exported_program.json": + serialized_exported_program = file_content + elif file_info.filename == "serialized_state_dict.json": + warnings.warn("This version of file is deprecated") + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.json": + warnings.warn("This version of file is deprecated") + serialized_constants = file_content + elif file_info.filename == "serialized_state_dict.pt": + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.pt": + serialized_constants = file_content + elif file_info.filename.startswith("extra_files"): + filename = file_info.filename.split("/", 1)[1] + extra_files[filename] = file_content.decode('utf-8') + + assert serialized_exported_program is not None + assert serialized_state_dict is not None + assert serialized_constants is not None + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + + return ep + + +def aot_compile( + f: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = 
None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Note: this function is not stable yet + + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside, generates executable cpp code from the program, and returns + the path to the generated shared library + + Args: + f: the `nn.Module` or callable to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + options: A dictionary of options to control inductor + + disable_constraint_solver: Whether the dim constraint solver must be disabled. 
+ + Returns: + Path to the generated shared library + """ + from torch.export._trace import _export_to_torch_ir + from torch._inductor.decomposition import select_decomp_table + + constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes) + + if config.is_predispatch: + gm = torch.export._trace._export(f, args, kwargs, constraints, pre_dispatch=True).module() + else: + # We want to export to Torch IR here to utilize the pre_grad passes in + # inductor, which run on Torch IR. + gm = _export_to_torch_ir( + f, + args, + kwargs, + constraints, + disable_constraint_solver=disable_constraint_solver, + # Disabling this flag, because instead we can rely on the mapping + # dynamo_flat_name_to_original_fqn which is coming from Dynamo. + restore_fqn=False, + ) + flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {})) + + with torch.no_grad(): + so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options) # type: ignore[arg-type] + + return so_path + +def aot_load(so_path: str, device: str) -> Callable: + """ + Loads a shared library generated by aot_compile and returns a callable + + Args: + so_path: Path to the shared library + + Returns: + A callable + """ + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d737548c3d480d11e722ad5ae076cebe9f2523c4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,52 @@ +import glob +import importlib +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, +) + + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") +] + +# Import all module in the current directory. +from . import * # noqa: F403 + + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2b13ff35e4e61b74e35ca992a0510257d632cb56 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9aeb9f455f448d5e856e66319d0d1461e0c920a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72149e1957fabd77fe1daa1d510390d597346c44 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5742e6fe37f9f34cadc5f63ed1895b4f9917dc2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfb175b58c9ab70fc5299063a9266e3a089d4885 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51d078a6e159367f00173ef4691db97f81f1001a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02cb1b93d2567c3116070866d2e2e0c4ecb1c20b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..5ccac4eb1228cbce4d7d67919d91ff487639d1d9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..385903bef06a03156d19d22d00ff52a138ba7cf5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20706b191dff4ab41ab711c189ab00b3a7f23951 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8ebc97a0461b27a207dfe4f34e3346f94d7a88a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0306c97799f4a5459daa339ef3062adc8802251 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38591c96a0819c5faf16d6a67f6418e57a184e63 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..967961fa65013395e838d251cbccc813baaf136b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2c174b4f584eb17c3526e103b92fe8c4341dbb8a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d752ca173915966fd6f7f5b9b806467bdbdda80 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b595cd8c651ee76166f4e1a2195827df62cf2cf0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..664aab8b64da2b239daaa2d78c068a1d7397c4a4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,24 @@ +import torch +import torch._dynamo as torchdynamo + +from torch._export.db.case import 
export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.tensor(4)), + tags={"torch.escape-hatch"}, +) +class AssumeConstantResult(torch.nn.Module): + """ + Applying `assume_constant_result` decorator to burn make non-tracable code as constant. + """ + + def __init__(self): + super().__init__() + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..77c629559d21eb6390c00ce8143d773d16f5710f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,24 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 4),), +) +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. 
+ """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8a1db034256fd305ae8924254070ac212e9248 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,44 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + def __init__(self): + super().__init__() + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..b201c5d679b8eab6e9a3a74705772acf3a9a5af8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.tensor(True), torch.ones(3, 2)), + tags={"torch.cond", "python.closure"}, +) +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. 
+ """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..a05e584100c958a124f9cfc59c489b417f5d3214 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,39 @@ +import torch + +from torch._export.db.case import export_case +from torch.export import Dim +from functorch.experimental.control_flow import cond + +x = torch.randn(3, 2) +y = torch.ones(2) +dim0_x = Dim("dim0_x") + +@export_case( + example_inputs=(x, y), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, + extra_inputs=(torch.randn(2, 2), torch.ones(2)), + dynamic_shapes={"x": {0: dim0_x}, "y": None}, +) +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..fd02e2484c54678712593f7c9fa28344e5574375 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,29 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(6, 4, 3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..1af4b22dc988816c011aa2eb085f97c9850d257a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,27 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.tensor(4),), + tags={ + "torch.dynamic-value", + "torch.escape-hatch", + }, +) +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at constrain_as_value and constrain_as_size APIs + constrain_as_size is used for values that NEED to be used for constructing + tensor. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + a = x.item() + torch._constrain_as_size(a, min=0, max=5) + return torch.ones((a, 5)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..3844c7227a365ceb157222c91d179296e73a3522 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,30 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.tensor(4), torch.randn(5, 5)), + tags={ + "torch.dynamic-value", + "torch.escape-hatch", + }, +) +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at constrain_as_value and constrain_as_size APIs. + constrain_as_value is used for values that don't need to be used for constructing + tensor. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x, y): + a = x.item() + torch._constrain_as_value(a, min=0, max=5) + + if a < 6: + return y.sin() + return y.cos() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..382b444d7f8a285e85c4f5530f01972918a6d96f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.tensor(4)), + tags={"python.data-structure"}, +) +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened along tracing. + """ + def __init__(self): + super().__init__() + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..45c8d36bee1fa7ed0102809a6871fbfa76628696 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2, 2),), + tags={"torch.dynamic-shape", "python.control-flow"}, +) +class DynamicShapeIfGuard(torch.nn.Module): + """ + `if` statement with backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. 
However, export will fail if the + the dimension is marked as dynamic shape from higher level API. + """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..ee45ffb288368dc35ffb76e385bcee20ced22235 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,20 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py new file mode 100644 index 0000000000000000000000000000000000000000..b763a4ec0ae3480a322dbd9b73664944c5e2d8bb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_view.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(10, 10),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeView(torch.nn.Module): + """ + Dynamic shapes should be propagated to view arguments instead of being + baked into the exported graph. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + new_x_shape = x.size()[:-1] + (2, 5) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..16b7e54613d7fe405d28878ff45bd3a6ce5a3f4a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape", "python.data-structure", "python.assert"}, +) +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..a5bd7fbd8edf523d4d6d11250bc9f8c8986653fd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,27 @@ +from typing import List + +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=([torch.ones(3, 2), torch.tensor(4), torch.tensor(5)],), + tags={"python.control-flow", "python.data-structure"}, +) +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + 
erased after tracing. + """ + + def __init__(self): + super().__init__() + + def forward(self, args: List[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + x, *y = args + return x + y[0] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d76cc67eda8cbb3306f27b2315ae35c7517aa2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,25 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.object-model"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation is not supported. 
+ """ + + def __init__(self): + super().__init__() + self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..0d799b2a609acc2b626e70f5c9beb131784f4e6b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,20 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel +from torch.utils import _pytree as pytree + + +@export_case( + example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), + support_level=SupportLevel.SUPPORTED, +) +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. 
+ """ + def __init__(self): + super().__init__() + + def forward(self, x): + y, spec = pytree.tree_flatten(x) + return y[0] + 1 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fc2b0ec36a5f9296aceb3146be74f07d5e5ac2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/scalar_output.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from torch.export import Dim + +x = torch.ones(3, 2) +dim1_x = Dim("dim1_x") + +@export_case( + example_inputs=(x,), + tags={"torch.dynamic-shape"}, + dynamic_shapes={"x": {1: dim1_x}}, +) +class ScalarOutput(torch.nn.Module): + """ + Returning scalar values from the graph is supported, in addition to Tensor + outputs. Symbolic shapes are captured and rank is specialized. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[1] + 1 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..743a357fc13ca984369cdddadf31bb4ee27e9109 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,29 @@ +from enum import Enum + +import torch + +from torch._export.db.case import export_case + + +class Animal(Enum): + COW = "moo" + + +@export_case( + example_inputs=(torch.ones(3, 2),), +) +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. 
+ """ + + def __init__(self): + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..9d030b6e82aa5f4c89c3e1c37622e53c1c4675f7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.control-flow"}, +) +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + ret = [] + for i in range(10): # constant + ret.append(i + x) + return ret diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..c258e430f7ea0fa4a5b58ef9d6988e936fbb0f3f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2, 2),), + tags={"python.control-flow"}, +) +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..fae18fb1cf934bf1a9437b70578d58cf10130a4e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,17 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.randn(3, 2), "attr"), + tags={"python.builtin"}, + support_level=SupportLevel.SUPPORTED, +) +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d78703e2d5ff96c15bd5b772fea10e044ffbfc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,41 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel, export_rewrite_case + + +class A: + @classmethod + def func(cls, x): + return 1 + x + + +@export_case( + example_inputs=(torch.ones(3, 4),), + tags={"python.builtin"}, + support_level=SupportLevel.SUPPORTED, +) +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + a = A() + return type(a).func(x) + + +@export_rewrite_case(parent=TypeReflectionMethod) +class TypeReflectionMethodRewrite(torch.nn.Module): + """ + Custom object class methods will be inlined. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return A.func(x) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..adcc708e554830b430db0d4374f4494482ce0b39 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,26 @@ +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. 
+ """ + + def call(self, graph_module) -> PassResult: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target == torch.ops.aten._assert_async.msg: + assert_async_node = node + if len(assert_async_node.users) > 0: + continue + module.graph.erase_node(assert_async_node) + # the upstream scalar_tensor <- {le, ge} <- sym_size + # linear chain of nodes of nodes is removed by the + # downstream dead code elimination + modified = True + return PassResult(graph_module, modified) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..97af59b700a792694a83b923b8c27b692356907f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,141 @@ +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import ( + node_inline_, + node_replace_, + nodes_filter, + nodes_first, + nodes_map, + sequential_split, +) + + +def _is_set_grad_enabled_node(node: torch.fx.Node): + return ( + node + and node.op == "call_function" + and node.target == torch._C._set_grad_enabled + ) + + +def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False): + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target == torch._C._set_grad_enabled + ): + return ( + first_non_ph.args[0] != torch.is_grad_enabled() + if 
omit_if_same_with_ambient + else True + ) + return False + + +def _replace_with_hop(node: torch.fx.Node): + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node) + if len(set_grad_nodes) > 0: + assert len(set_grad_nodes) == 1 + set_grad_node = set_grad_nodes[0] + enable_grad_val = set_grad_node.args[0] + with graph.inserting_before(node): + get_attr_node = graph.get_attr(node.target) + output_node = next(iter(reversed(sub_gm.graph.nodes)), None) + if output_node is not None: + assert len(output_node.args) == 1 + output_args = output_node.args[0] + if isinstance(output_args, (tuple, list)): + call_func_node = graph.call_function( + wrap_with_set_grad_enabled, + (enable_grad_val, get_attr_node, *node.args), + {}, + ) + # Create the metadata + call_func_node.meta["val"] = tuple( + arg.meta["val"] for arg in output_args + ) + node_replace_(node, call_func_node, delete_old=True) + + # Rename the name of getitem nodes to the actual name of its contents + # for passing verifier and better readability, also propagate metadata + for get_item_node in call_func_node.users.keys(): + idx: int = get_item_node.args[1] + output_node = output_args[idx] + get_item_node._rename(output_node.name) + get_item_node.meta = output_node.meta + pass + + elif isinstance(output_args, torch.fx.Node): + call_func_node = graph.create_node( + "call_function", + wrap_with_set_grad_enabled, + (enable_grad_val, get_attr_node, *node.args), + {}, + output_args.name, + ) + call_func_node.meta = output_args.meta + node_replace_(node, call_func_node, delete_old=True) + else: + raise NotImplementedError( + f"repalce_set_grad_with_hop_pass doesnt' support output type {type(output_args)}" + ) + else: + raise NotImplementedError( + "Cannot replace a call_module with a hop 
if it has no output. This module will gets DCEed." + ) + sub_graph.erase_node(set_grad_node) + + +def _remove_set_grad_and_inline(node: torch.fx.Node): + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + nodes_map( + sub_graph.nodes, + lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n, + ) + node_inline_(node) + + +def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule): + # If there is no set_grad_enabled node, return the original graph module + need_replacing = False + for node in gm.graph.nodes: + if _is_set_grad_enabled_node(node): + need_replacing = True + + if not need_replacing: + return gm + + new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + nodes_map( + list(new_gm.graph.nodes), + lambda node: _maybe_inline_or_replace_with_hop(node) + if node.op == "call_module" + else node, + ) + new_gm.graph.lint() + return new_gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58e5ed30d86debc6063a5d150c89d06faf95a342 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/utils.py @@ -0,0 +1,401 @@ +import dataclasses +import math +import operator +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type + +import torch +from torch._subclasses.fake_tensor import FakeTensor + +from torch.export import ExportedProgram +from torch.utils._pytree import ( + _register_pytree_node, + Context, + FlattenFunc, + 
FromDumpableContextFn, + KeyPath, + keystr, + MappingKey, + SequenceKey, + ToDumpableContextFn, + UnflattenFunc, +) + + +def _check_input_constraints_for_graph( + input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints +): + def get_keystr(key_path: KeyPath) -> str: + """For a given index into the flat_args, return a human readable string + describing how to access it, e.g. "*args["foo"][0].bar" + """ + # Prefix the keypath with "*args" or "**kwargs" to make it clearer where + # the arguments come from. Ultimately we ought to serialize the + # original arg names for the best error message here. + args_kwargs_key_path = key_path[0] + assert isinstance(args_kwargs_key_path, SequenceKey) + if args_kwargs_key_path.idx == 0: + return f"*args{keystr(key_path[1:])}" + else: + kwarg_key = key_path[1] + assert isinstance(kwarg_key, MappingKey) + name = str(kwarg_key)[1:-1] # get rid of the enclosed [] + return f"{name}{keystr(key_path[2:])}" + + import sympy + + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, + ) + from torch.utils._sympy.solve import try_solve + + if len(flat_args_with_path) != len(input_placeholders): + raise RuntimeError( + "Unexpected number of inputs " + f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" + ) + # NOTE: export already guarantees that the same symbol is used in metadata + # for all InputDims related by equality constraints, so we can just unify + # symbols with given input dimension values to check equality constraints. 
+ unification_map: "Dict[sympy.Symbol, Any]" = {} + for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): + node_val = node.meta.get("val") + if isinstance(node_val, FakeTensor): + if not isinstance(arg, torch.Tensor): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", + ) + + if len(node_val.shape) != len(arg.shape): + raise RuntimeError( + f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " + f"(expected {node_val.shape}, got {arg.shape})" + ) + + for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): + # TODO(avik): Assert the following property in the IR verifier: + # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr + if ( + isinstance(node_dim, torch.SymInt) + and len(node_dim.node.expr.free_symbols) == 1 + ): + symbol = next(iter(node_dim.node.expr.free_symbols)) + if symbol in unification_map: + existing_dim = node_dim.node.expr.subs(unification_map) + if arg_dim != existing_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{existing_dim}, but got {arg_dim}", + ) + else: + if ( + isinstance(arg_dim, torch.SymInt) + and not arg_dim.node.expr.is_number + ): + # This can happen when, say, arg is a fake tensor. + # We do not run checks on symbolic shapes of fake inputs as + # such checks can affect the shape env. 
+ pass + else: + solution = try_solve( + sympy.Eq(node_dim.node.expr, arg_dim), symbol + ) + if solution is None: + raise RuntimeError( # noqa: TRY200 + f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " + f"of the form {node_dim.node.expr}, where {symbol} is an integer" + ) + else: + unification_map[symbol] = int(solution[1]) + + if node_dim.node.expr in range_constraints: + min_val, max_val = _convert_range_to_int( + range_constraints[node_dim.node.expr] + ) + # NOTE: we allow dimensions to be 0/1 at runtime + if min_val > 2: + if arg_dim < min_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " + f"{min_val}, but got {arg_dim}", + ) + if max_val < math.inf: + if arg_dim > max_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " + f"{max_val}, but got {arg_dim}", + ) + else: + if arg_dim != node_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{node_dim}, but got {arg_dim}", + ) + elif isinstance(node_val, (int, float, str)): + if type(arg) != type(node_val) or arg != node_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + + +def register_dataclass_as_pytree_node( + cls: Type[Any], + flatten_fn: Optional[FlattenFunc] = None, + unflatten_fn: Optional[UnflattenFunc] = None, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + return_none_fields: bool = False, +) -> None: + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" + + def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for f in dataclasses.fields(obj): + name, val = f.name, getattr(obj, f.name) + if val is not None or return_none_fields: + 
flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn + unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Checks if the given node is a parameter within the exported program + """ + + return node.name in program.graph_signature.inputs_to_parameters + + +def get_param( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.nn.Parameter]: + """ + Returns the parameter associated with the given node in the exported program. + Returns None if the node is not a parameter within the exported program + """ + + if is_param(program, node): + parameter_name = program.graph_signature.inputs_to_parameters[node.name] + return program.state_dict[parameter_name] + + return None + + +def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Checks if the given node is a buffer within the exported program + """ + + return node.name in program.graph_signature.inputs_to_buffers + + +def get_buffer( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the buffer associated with the given node in the exported program. 
+ Returns None if the node is not a buffer within the exported program + """ + + if is_buffer(program, node): + buffer_name = program.graph_signature.inputs_to_buffers[node.name] + if buffer_name in program.graph_signature.non_persistent_buffers: + return program.constants[buffer_name] + else: + return program.state_dict[buffer_name] + + return None + + +def is_lifted_tensor_constant( + program: ExportedProgram, + node: torch.fx.Node, +) -> bool: + """ + Checks if the given node is a lifted tensor constant within the exported program + """ + + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def get_lifted_tensor_constant( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the lifted tensor constant associated with the given node in the exported program. + Returns None if the node is not a lifted tensor constant within the exported program + """ + + if is_lifted_tensor_constant(program, node): + lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ + node.name + ] + return program.constants[lifted_tensor_name] + + return None + + +def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: + """ + Splits the graph module into multiple submodules based on the node_call_back. + The node_call_back should return True if the node is a delimiter. Delimiter will be + the first node in the next submodule. + """ + from torch.fx.passes.split_module import split_module + + split_map = {} + split_id = 0 + for node in gm.graph.nodes: + if node_call_back(node): + split_id += 1 + split_map[node] = split_id + + new_gm = split_module( + gm, + gm, + lambda node: split_map[node], + keep_original_order=True, + keep_original_node_name=True, + ) + # Keep the codegen from original graph module to preserve e.g. pytree info. 
+ new_gm.graph._codegen = gm.graph._codegen + new_gm.recompile() + return new_gm + + +def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """Returns the nodes that match the node_call_back as a list.""" + return [node for node in nodes if node_call_back(node)] + + +def nodes_first( + nodes: List[torch.fx.Node], node_call_back=None +) -> Optional[torch.fx.Node]: + """ + Returns the first node that matches the node_call_back. If no node matches, returns None. + When node_call_back is None, returns the first node in the node list. + """ + ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) + if len(ret) > 0: + return ret[0] + return None + + +def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: + """Returns the number of nodes that match the node_call_back.""" + return len(nodes_filter(nodes, node_call_back)) + + +def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """ + Sequentially visit the nodes list and invoke node_call_back on each element. + Returns the nodes list after the node_call_back is invoked on each element. + """ + for node in nodes: + node_call_back(node) + return nodes + + +def node_replace_( + old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False +) -> None: + """ + Replace all uses of old_node with new_node. + """ + old_node.replace_all_uses_with(new_node) + if delete_old: + old_node.users.clear() + old_node.graph.erase_node(old_node) + + +def node_inline_(call_mod_node: torch.fx.Node) -> None: + """ + Inline the submodule of the given node into the parent module. + Note: we only support the case where submodule takes tensors inputs. 
+ """ + assert call_mod_node.op == "call_module" + gm = call_mod_node.graph.owning_module + + assert isinstance(call_mod_node.target, str) + sub_gm = getattr(gm, call_mod_node.target) + + phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") + body = ( + node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") + ) + output = [node for node in sub_gm.graph.nodes if node.op == "output"] + + for ph, arg in zip(phs, call_mod_node.args): + assert isinstance(arg, torch.fx.Node) + node_replace_(ph, arg, delete_old=True) + + with gm.graph.inserting_before(call_mod_node): + for node in body: + new_node = gm.graph.node_copy(node) + node_replace_(node, new_node, delete_old=True) + + if len(output) > 0: + assert len(output) == 1 and len(output[0].args) == 1 + new_output = output[0].args[0] + + if isinstance(new_output, torch.fx.Node): + node_replace_(call_mod_node, new_output, delete_old=True) + elif isinstance(new_output, (list, tuple)): + # Inline the get_item calls for the output node. + get_item_users = nodes_filter( + list(call_mod_node.users.keys()), + lambda node: node.op == "call_function" + and node.target == operator.getitem, + ) + # get_item_node.args[1] is the idx referring to new_output[idx] + nodes_map( + get_item_users, + lambda get_item_node: node_replace_( + get_item_node, + new_output[get_item_node.args[1]], + delete_old=True, + ), + ) + call_mod_node.graph.erase_node(call_mod_node) + else: + raise NotImplementedError( + f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." 
+ ) + else: + call_mod_node.graph.erase_node(call_mod_node) + + gm.delete_all_unused_submodules() + gm.recompile() + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca2375ec124fe89f5713cd11a6a6046bdec8a45 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py @@ -0,0 +1,114 @@ +from contextlib import contextmanager + +import torch +import torch._custom_ops +from torch._C import DispatchKey +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +_export_tracepoint = HigherOrderOperator("_export_tracepoint") + + +@_export_tracepoint.py_impl(ProxyTorchDispatchMode) +def export_tracepoint_dispatch_mode(mode, *args, **kwargs): + if not mode.enable_tracing: + return _export_tracepoint(*args, **kwargs) + p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) + proxy = mode.tracer.create_proxy( + "call_function", _export_tracepoint, p_args, p_kwargs + ) + return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) + + +@_export_tracepoint.py_impl(FakeTensorMode) +def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): + with mode: + return args + + +@_export_tracepoint.py_functionalize_impl +def export_tracepoint_functional(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) + return 
ctx.wrap_tensors(out) + + +_export_tracepoint.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_export_tracepoint, deferred_error=True) +) + + +@_export_tracepoint.py_impl(DispatchKey.CPU) +def export_tracepoint_cpu(*args, **kwargs): + return args + + +def _wrap_submodule(mod, path, module_call_specs): + assert isinstance(mod, torch.nn.Module) + assert path != "" + submodule = mod + for name in path.split("."): + if not hasattr(submodule, name): + raise RuntimeError(f"Couldn't find submodule at path {path}") + submodule = getattr(submodule, name) + + def update_module_call_signatures(path, in_spec, out_spec): + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec + module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} + + def check_flattened(flat_args): + for a in flat_args: + if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): + raise AssertionError( + f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" + ) + + def pre_hook(module, args, kwargs): + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + check_flattened(flat_args) + flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) + args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + return args, kwargs + + def post_hook(module, args, kwargs, res): + _, in_spec = pytree.tree_flatten((args, kwargs)) + flat_res, out_spec = pytree.tree_flatten(res) + check_flattened(flat_res) + flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) + update_module_call_signatures(path, in_spec, out_spec) + return pytree.tree_unflatten(flat_res, out_spec) + + pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) + post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) + return pre_handle, post_handle + + +@contextmanager +def _wrap_submodules(f, preserve_signature, 
module_call_signatures): + handles = [] + + try: + for path in preserve_signature: + handles.extend(_wrap_submodule(f, path, module_call_signatures)) + yield + finally: + for handle in handles: + handle.remove() + + +def _mark_strict_experimental(cls): + def call(self, *args): + return strict_mode(self, args) + + cls.__call__ = call + return cls diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3f0a302f591864b20430e30737fe8d92d4fda22 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614ba08445a9113d038042f5e069bffdf44acd6b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbbd2410aa92b8ff80e03ac9f136315e3e4a5c37 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecf79b9a2ab5b5fa158a32d28a8afd9f8ef8c93c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e085ae4020fe44b8868cf7b328c8dcd402d341f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8688a7d953ee97d014c0fabb50ad3db54c3b4d5f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..328fd4b12c32a027313b8c72f8b433e6ceaf01ce Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31a93fb38d3251c83e2177038022d9ddb44b6486 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b05a83b8396ba9817b6cb73abcaa208685497760 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a8bc8509c6554039f86b143fb08e7e656e5d21d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97551612554d04ad7d88d6b150bd03f909819d9c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..75f7d1ad5f1e9f15890970222d54a2ee492acc17 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Dict, List + +import torch + +from torch.fx import Node + +from .quantizer import QuantizationAnnotation, Quantizer + +__all__ = [ + "ComposableQuantizer", +] + + +class ComposableQuantizer(Quantizer): + """ + ComposableQuantizer allows users to combine more than one quantizer into a single quantizer. + This allows users to quantize a model with multiple quantizers. E.g., embedding quantization + maybe supported by one quantizer while linear layers and other ops might be supported by another + quantizer. + + ComposableQuantizer is initialized with a list of `Quantizer` instances. + The order of the composition matters since that is the order in which the quantizers will be + applies. 
+ Example: + ``` + embedding_quantizer = EmbeddingQuantizer() + linear_quantizer = MyLinearQuantizer() + xnnpack_quantizer = XNNPackQuantizer() # to handle ops not quantized by previous two quantizers + composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer]) + prepared_m = prepare_pt2e(model, composed_quantizer) + ``` + """ + + def __init__(self, quantizers: List[Quantizer]): + super().__init__() + self.quantizers = quantizers + self._graph_annotations: Dict[Node, QuantizationAnnotation] = {} + + def _record_and_validate_annotations( + self, gm: torch.fx.GraphModule, quantizer: Quantizer + ) -> None: + for n in gm.graph.nodes: + if "quantization_annotation" in n.meta: + # check if the annotation has been changed by + # comparing QuantizationAnnotation object id + if n in self._graph_annotations and ( + id(self._graph_annotations[n]) + != id(n.meta["quantization_annotation"]) + ): + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" + ) + else: + self._graph_annotations[n] = n.meta["quantization_annotation"] + else: + if n in self._graph_annotations: + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}" + ) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + for quantizer in self.quantizers: + quantizer.annotate(model) + self._record_and_validate_annotations(model, quantizer) + return model + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for quantizer in self.quantizers: + model = quantizer.transform_for_annotation(model) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/functional.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbcd02d13d09d0c980ce1fa5f556437d2574acb4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/functional.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3d47ffbc0dba191da15ce1e682e5f8d649b7baf Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ee0cc168aea662232d59da636d03a558d3ebd2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__init__.py @@ -0,0 +1,70 @@ +r"""Quantized Modules. + +Note:: + The `torch.nn.quantized` namespace is in the process of being deprecated. + Please, use `torch.ao.nn.quantized` instead. 
+""" + +from torch.ao.nn.quantized.modules.activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU +from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d +from torch.ao.nn.quantized.modules.conv import Conv1d, Conv2d, Conv3d +from torch.ao.nn.quantized.modules.conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d +from torch.ao.nn.quantized.modules.dropout import Dropout +from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag +from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional, FXFloatFunctional, QFunctional +from torch.ao.nn.quantized.modules.linear import Linear +from torch.ao.nn.quantized.modules.normalization import LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d +from torch.ao.nn.quantized.modules.rnn import LSTM + +from torch.ao.nn.quantized.modules import MaxPool2d +from torch.ao.nn.quantized.modules import Quantize, DeQuantize + +# The following imports are needed in case the user decides +# to import the files directly, +# s.a. `from torch.nn.quantized.modules.conv import ...`. +# No need to add them to the `__all__`. 
+from torch.ao.nn.quantized.modules import activation +from torch.ao.nn.quantized.modules import batchnorm +from torch.ao.nn.quantized.modules import conv +from torch.ao.nn.quantized.modules import dropout +from torch.ao.nn.quantized.modules import embedding_ops +from torch.ao.nn.quantized.modules import functional_modules +from torch.ao.nn.quantized.modules import linear +from torch.ao.nn.quantized.modules import normalization +from torch.ao.nn.quantized.modules import rnn +from torch.ao.nn.quantized.modules import utils + +__all__ = [ + 'BatchNorm2d', + 'BatchNorm3d', + 'Conv1d', + 'Conv2d', + 'Conv3d', + 'ConvTranspose1d', + 'ConvTranspose2d', + 'ConvTranspose3d', + 'DeQuantize', + 'ELU', + 'Embedding', + 'EmbeddingBag', + 'GroupNorm', + 'Hardswish', + 'InstanceNorm1d', + 'InstanceNorm2d', + 'InstanceNorm3d', + 'LayerNorm', + 'LeakyReLU', + 'Linear', + 'LSTM', + 'MultiheadAttention', + 'Quantize', + 'ReLU6', + 'Sigmoid', + 'Softmax', + 'Dropout', + 'PReLU', + # Wrapper modules + 'FloatFunctional', + 'FXFloatFunctional', + 'QFunctional', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..857fdc56fa12820fc5c0761a4d1450e7684e4ec1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c72d94a7bd146b4d8e223685ae3f6a02babc4d03 
Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/functional_modules.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c600f84e776d67c7381b22e30a936edfbcf17438 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/functional_modules.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional'] + +from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional +from torch.ao.nn.quantized.modules.functional_modules import FXFloatFunctional +from torch.ao.nn.quantized.modules.functional_modules import QFunctional diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/normalization.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..1127bf9acb81ea9a5803bd18181f25a311cefa07 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/normalization.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r"""Quantized Modules. 
+ +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] + +from torch.ao.nn.quantized.modules.normalization import LayerNorm +from torch.ao.nn.quantized.modules.normalization import GroupNorm +from torch.ao.nn.quantized.modules.normalization import InstanceNorm1d +from torch.ao.nn.quantized.modules.normalization import InstanceNorm2d +from torch.ao.nn.quantized.modules.normalization import InstanceNorm3d diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9021bd193928fcbd771c4f84abbeb056b6f5774f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a90a6a543905090f864b42f1d1c93b68c9df41 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..3a82b2b426aa0a1bdbe64cdc177ba42219b78fdc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py @@ -0,0 +1,374 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Iterable, List, Tuple + +import torch + + +_MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + if name in module._parameters: + module._parameters[name] = tensor # type: ignore[assignment] + elif name in module._buffers: + module._buffers[name] = tensor + else: + setattr(module, name, tensor) + + +def swap_tensor( + module: "torch.nn.Module", + name: str, + tensor: torch.Tensor, + allow_missing: bool = False, +) -> torch.Tensor: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if ( + tensor is not _MISSING + and not isinstance(tensor, torch.Tensor) + and tensor is not None + ): + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." 
in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + + orig_tensor: torch.Tensor + if name in module._parameters: + orig_tensor = module._parameters[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._parameters[name] = tensor # type: ignore[assignment] + else: + del module._parameters[name] + elif name in module._buffers: + orig_tensor = module._buffers[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._buffers[name] = tensor + else: + del module._buffers[name] + else: + try: + orig_tensor = getattr(module, name) + except AttributeError as ex: + if not allow_missing: + raise AttributeError( + f"{module._get_name()} has no attribute `{name}`" + ) from ex + orig_tensor = _MISSING + if ( + orig_tensor is not _MISSING + and not isinstance(orig_tensor, torch.Tensor) + and orig_tensor is not None + ): + raise TypeError( + f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor" + ) + if tensor is not _MISSING: + setattr(module, name, tensor) + elif hasattr(module, name): + delattr(module, name) + return orig_tensor + + +def swap_submodule( + module: "torch.nn.Module", + name: str, + submodule: "torch.nn.Module", +) -> "torch.nn.Module": + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(submodule, torch.nn.Module): + raise TypeError(f"{submodule} is not an instance of torch.nn.Module") + if "." 
in name: + raise KeyError('submodule name can\'t contain "."') + if name == "": + raise KeyError('submodule name can\'t be empty string ""') + if name not in module._modules: + raise KeyError(f"submodule {name} does not exist") + + orig_submodule = module._modules[name] + if not isinstance(orig_submodule, torch.nn.Module): + raise TypeError(f"{name} attribute is not an instance of torch.nn.Module") + module._modules[name] = submodule + return orig_submodule + + +class NamedMemberAccessor: + """ + A class that provides a way to access the submodules and parameters/buffers of a module. + + It provides caching mechanism to speed up submodule lookups. + This is useful for functional programming to manipulate the module state. + """ + + def __init__(self, module: "torch.nn.Module") -> None: + self.module = module + self.memo: Dict[str, torch.nn.Module] = {} + + # Nested attribute access + + def get_submodule(self, name: str) -> "torch.nn.Module": + """ + Return the submodule specified by the given path. + + For example, to get the submodule mod.layer1.conv1, + use accessor.get_submodule("layer1.conv1") + + Compare to mod.get_submodule("layer1.conv1"), this method will cache the + intermediate submodule access to speed up future lookups. 
+ """ + if not name: + return self.module + + try: + return self.memo[name] + except KeyError: + prefix, dot, attr = name.rpartition(".") + if dot: + module = self.get_submodule(prefix) + else: + module = self.module + try: + submodule = getattr(module, attr) + except AttributeError as ex: + raise AttributeError( + f"{module._get_name()} has no attribute `{attr}`" + ) from ex + if not isinstance(submodule, torch.nn.Module): + raise TypeError( # noqa: TRY200 + f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" + ) + self.memo[name] = submodule + return submodule + + def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module": + """ + Swap the submodule specified by the given ``path`` to ``value``. + + For example, to swap the attribute mod.layer1.conv1 use + ``accessor.swap_submodule("layer1.conv1", conv2)``. + """ + prefix, _, attr = path.rpartition(".") + return swap_submodule(self.get_submodule(prefix), attr, value) + + def get_tensor(self, name: str) -> torch.Tensor: + """ + Get the tensor specified by the given path to value. + + For example, to get the attribute mod.layer1.conv1.weight, + use accessor.get_tensor('layer1.conv1.weight') + + Compare to mod.get_parameter("layer1.conv1.weight"), this method will + cache the intermediate submodule access to speed up future lookups. + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + tensor = getattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + return tensor # type: ignore[return-value] + + def set_tensor(self, name: str, value: torch.Tensor) -> None: + """ + Set the attribute specified by the given path to value. 
        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.

        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.

        Returns the tensor previously stored at ``name``.

        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        # Materialize one-shot iterators so the length check below is safe.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.

        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

        Returns the tensors previously stored at each path, in order.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        # Materialize one-shot iterators so the length check below is safe.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        Returns a dict of the original tensors keyed by path, plus the list of
        paths that were missing from the module. The swap is all-or-nothing:
        on any failure (or on missing keys when ``allow_missing`` is False)
        every already-swapped tensor is restored before raising.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Check that the given keys are valid."""
        keys = set(keys)
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys =
keys - valid_keys + return sorted(missing_keys), sorted(unexpected_keys) + + # Shortcut methods + + def named_parameters( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the parameters in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + + def named_buffers( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the buffers in the module.""" + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_tensors( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the tensors in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_modules( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, "torch.nn.Module"]]: + """Iterate over all the modules in the module.""" + yield from self.module.named_modules(remove_duplicate=remove_duplicate) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..0644ab5d2535e07360c77cebe838ab680c842362 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py @@ -0,0 +1,102 @@ +import functools + +import torch +from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight + +from torch.utils import _pytree as pytree + + +# dependency on `functional_call` means that this can't be exposed in utils +# without creating circular dependency +def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True): + r""" + Return a forward function for a 
    module, populating grad_sample with per sample gradients on backward invocation.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
            parameters will compute per sample gradients, located in a ``grad_sample``
            field when ``backward`` is invoked
        batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have
            the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually.
            Default: None
        loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
            "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from
            running mean across a batch. Must be "mean" or "sum". Default: "sum"
        batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first
            dimension. If False, it's the second dimension. Default: True.

    Examples::
        >>> # xdoctest: +SKIP
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model)(batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

    An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be
    if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all
    grad_outputs by 1 / batch_size from cross batch interaction.
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean()
        >>> res.backward()

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example
    """

    def maybe_build_expanded_weight(og_tensor, batch_size):
        # Only trainable tensors are wrapped in ExpandedWeight; frozen tensors
        # pass through unchanged so they get no per-sample gradient machinery.
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size, loss_reduction)
        else:
            return og_tensor

    def compute_batch_size(*args, **kwargs):
        # Infer the batch size from the tensor leaves of args/kwargs; every
        # tensor must agree on the size of the batch dimension.
        args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs)
        batch_size = None
        for arg in args_and_kwargs:
            if not isinstance(arg, torch.Tensor):
                continue

            # Batch dim is dim 0 unless batch_first=False, then it is dim 1.
            arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
            if batch_size is not None and batch_size != arg_batch_size:
                raise RuntimeError("When computing batch size, found at least one input with batch size "
                                   f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "
                                   "explicitly using the batch size kwarg in call_for_per_sample_grads")
            batch_size = arg_batch_size
        if batch_size is None:
            raise RuntimeError("Unable to find a tensor in the passed args and kwargs. They may not be pytree-able "
                               "and so ExpandedWeights cannot compute the batch size from the inputs.
Please specify " + "it explicitly") + return batch_size + + if loss_reduction not in ["sum", "mean"]: + raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}") + + if not isinstance(module, torch.nn.Module): + raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}") + if not (batch_size is None or isinstance(batch_size, int)): + raise RuntimeError(f"Batch size passed must be None or an integer, got {type(batch_size).__name__}") + if batch_size is not None and batch_size < 1: + raise RuntimeError(f"Batch size must be positive, got {batch_size}") + for weight in module.parameters(): + if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined] + raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple " + f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or " + "post an issue to pytorch/pytorch to prioritize correct behavior") + + @functools.wraps(module.forward) + def wrapper(*args, **kwargs): + wrapper_batch_size = batch_size + if wrapper_batch_size is None: + wrapper_batch_size = compute_batch_size(*args, **kwargs) + + params = {name: maybe_build_expanded_weight(value, wrapper_batch_size) for (name, value) in module.named_parameters()} + return torch.func.functional_call(module, params, args, kwargs) + return wrapper diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/convert_parameters.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/convert_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..e23352b6b6d9bb2f32df6fb26401e4b8d9281636 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/convert_parameters.py @@ -0,0 +1,83 @@ +import torch +from typing import Iterable, Optional + + +def parameters_to_vector(parameters: 
Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector
    """
    # Flag for the device where the parameter is located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)

        vec.append(param.view(-1))
    return torch.cat(vec)


def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f'expected torch.Tensor, but got: {torch.typename(vec)}')
    # Flag for the device where the parameter is located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        # NOTE(review): the assignment goes through ``.data``, so the copy is
        # not recorded by autograd — presumably intentional; confirm callers
        # do not expect this write to be differentiable.
        param.data = vec[pointer:pointer + num_param].view_as(param).data

        # Increment the pointer
        pointer += num_param


def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check if the parameters are located on the same device.

    Currently, the conversion between model parameters and single vector form is not supported
    for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1.
+ + Args: + param ([Tensor]): a Tensor of a parameter of a model + old_param_device (int): the device where the first parameter of a + model is allocated. + + Returns: + old_param_device (int): report device for the first time + """ + # Meet the first parameter + support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()] + if old_param_device is None: + old_param_device = param.get_device() if param.device.type in support_device_types else -1 + else: + warn = False + if param.device.type in support_device_types: # Check if in same GPU/PrivateUse1 + warn = (param.get_device() != old_param_device) + else: # Check if in CPU + warn = (old_param_device != -1) + if warn: + raise TypeError('Found two parameters on different devices, ' + 'this is currently not supported.') + return old_param_device diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/fusion.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9433d9c376df81787e91a4ca4dd3698107f32bc5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/fusion.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import copy +from typing import Optional, Tuple, TypeVar + +import torch + +__all__ = ['fuse_conv_bn_eval', 'fuse_conv_bn_weights', 'fuse_linear_bn_eval', 'fuse_linear_bn_weights'] + +ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd") +LinearT = TypeVar("LinearT", bound="torch.nn.Linear") + +def fuse_conv_bn_eval(conv: ConvT, bn: torch.nn.modules.batchnorm._BatchNorm, transpose: bool = False) -> ConvT: + r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module. + + Args: + conv (torch.nn.modules.conv._ConvNd): A convolutional module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. 
+ transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False. + + Returns: + torch.nn.modules.conv._ConvNd: The fused convolutional module. + + .. note:: + Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (conv.training or bn.training), "Fusion only for eval!" + fused_conv = copy.deepcopy(conv) + + assert bn.running_mean is not None and bn.running_var is not None + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, fused_conv.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose) + + return fused_conv + +def fuse_conv_bn_weights( + conv_w: torch.Tensor, + conv_b: Optional[torch.Tensor], + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: Optional[torch.Tensor], + bn_b: Optional[torch.Tensor], + transpose: bool = False +) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters. + + Args: + conv_w (torch.Tensor): Convolutional weight. + conv_b (Optional[torch.Tensor]): Convolutional bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (Optional[torch.Tensor]): BatchNorm weight. + bn_b (Optional[torch.Tensor]): BatchNorm bias. + transpose (bool, optional): If True, transpose the conv weight. Defaults to False. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias. 
+ """ + conv_weight_dtype = conv_w.dtype + conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype) + fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype) + + return ( + torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad) + ) + +def fuse_linear_bn_eval(linear: LinearT, bn: torch.nn.modules.batchnorm._BatchNorm) -> LinearT: + r"""Fuse a linear module and a BatchNorm module into a single, new linear module. + + Args: + linear (torch.nn.Linear): A Linear module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + + Returns: + torch.nn.Linear: The fused linear module. + + .. note:: + Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (linear.training or bn.training), "Fusion only for eval!" + fused_linear = copy.deepcopy(linear) + + """ + Linear-BN needs to be fused while preserving the shapes of linear weight/bias. + To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear, + because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in). + To be broadcastable, the number of features in bn and + the number of output features from linear must satisfy the following condition: + 1. they are equal, or + 2. 
the number of features in bn is 1 + Otherwise, skip the folding path + """ + assert ( + linear.out_features == bn.num_features or bn.num_features == 1 + ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1" + + assert bn.running_mean is not None and bn.running_var is not None + fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights( + fused_linear.weight, fused_linear.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + + return fused_linear + +def fuse_linear_bn_weights( + linear_w: torch.Tensor, + linear_b: Optional[torch.Tensor], + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: torch.Tensor, + bn_b: torch.Tensor, +) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters. + + Args: + linear_w (torch.Tensor): Linear weight. + linear_b (Optional[torch.Tensor]): Linear bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (torch.Tensor): BatchNorm weight. + bn_b (torch.Tensor): BatchNorm bias. + transpose (bool, optional): If True, transpose the conv weight. Defaults to False. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias. 
+ """ + if linear_b is None: + linear_b = torch.zeros_like(bn_rm) + bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps) + + fused_w = linear_w * bn_scale.unsqueeze(-1) + fused_b = (linear_b - bn_rm) * bn_scale + bn_b + + return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(fused_b, linear_b.requires_grad) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..416ad0db8ef7ef64301614184f611a52c1a01e31 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py @@ -0,0 +1,53 @@ +import inspect +import torch + + +def skip_init(module_cls, *args, **kwargs): + r""" + Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers. + + This can be useful if initialization is slow or if custom initialization will + be performed, making the default initialization unnecessary. There are some caveats to this, due to + the way this function is implemented: + + 1. The module must accept a `device` arg in its constructor that is passed to any parameters + or buffers created during construction. + + 2. The module must not perform any computation on parameters in its constructor except + initialization (i.e. functions from :mod:`torch.nn.init`). + + If these conditions are satisfied, the module can be instantiated with parameter / buffer values + uninitialized, as if having been created using :func:`torch.empty`. 
+ + Args: + module_cls: Class object; should be a subclass of :class:`torch.nn.Module` + args: args to pass to the module's constructor + kwargs: kwargs to pass to the module's constructor + + Returns: + Instantiated module with uninitialized parameters / buffers + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> import torch + >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) + >>> m.weight + Parameter containing: + tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], + requires_grad=True) + >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) + >>> m2.weight + Parameter containing: + tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, + 4.5915e-41]], requires_grad=True) + + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f'Expected a Module; got {module_cls}') + if 'device' not in inspect.signature(module_cls).parameters: + raise RuntimeError('Module must support a \'device\' arg to skip initialization') + + final_device = kwargs.pop('device', 'cpu') + kwargs['device'] = 'meta' + return module_cls(*args, **kwargs).to_empty(device=final_device) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/memory_format.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/memory_format.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fc22bea51cfc47006d1918d977afd2c4f3310b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/memory_format.py @@ -0,0 +1,143 @@ +import torch + + +def convert_conv2d_weight_memory_format(module, memory_format): + r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``. + + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. 
+ This function is used to facilitate the computation to adopt NHWC kernels, which + provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0 + + .. note:: + Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive + than the utility function ``convert_conv2d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NHWC(channels_last) conversion for + convolution in cuDNN, As it is beneficial to run convolution in NHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format) + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. 
    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv2d``

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv2d(8, 4, 3)).cuda().half()
        >>> # This is identical to:
        >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> out = model(input)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        # Clone into the requested memory_format, then re-point the existing
        # parameter's storage via ``.data`` so the Parameter object itself (and
        # any optimizer references to it) is preserved.
        weight_data = module.weight.detach().clone().contiguous(memory_format=memory_format)
        module.weight.data = weight_data.resize_(weight_data.size(), memory_format=memory_format)
    # Recurse into children so arbitrarily nested containers are handled too.
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    return module


def convert_conv3d_weight_memory_format(module, memory_format):
    r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``
    The conversion recursively applies to nested ``nn.Module``, including ``module``.
    Note that it only changes the memory_format, but not the semantics of each dimensions.
    This function is used to facilitate the computation to adopt NHWC kernels, which
    provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0

    .. note::
        Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
        than the utility function ``convert_conv3d_weight_memory_format``.
Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NHWC(channels_last) conversion for + convolution in cuDNN, As it is beneficial to run convolution in NHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format) + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. 
``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv3d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") + >>> model = nn.Sequential( + >>> nn.Conv3d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last) + >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last) + >>> out = model(input) + """ + + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = module.weight.detach().clone().contiguous(memory_format=memory_format) + module.weight.data = weight_data.resize_(weight_data.size(), memory_format=memory_format) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + return module diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b25bcac0cb7bbc67b8f99bfc24960b2e54b8f7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py @@ -0,0 +1,571 @@ +from enum import Enum, auto + +import torch +from torch import Tensor +from ..utils import parametrize +from ..modules import Module +from .. import functional as F + +from typing import Optional + +__all__ = ['orthogonal', 'spectral_norm', 'weight_norm'] + + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = torch.eye(k, dtype=Q.dtype, device=Q.device) + # A reasonable eps, but not too large + eps = 10. 
* n * torch.finfo(Q.dtype).eps + return torch.allclose(Q.mH @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """Assume that A is a tall matrix. + + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative. + """ + X, tau = torch.geqrf(A) + Q = torch.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__(self, + weight, + orthogonal_map: _OrthMaps, + *, + use_trivialization=True) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. 
        # Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint
        # An equivalent reasoning holds for rectangular matrices
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            # "base" holds the fixed matrix B used by the dynamic trivialization,
            # so that forward() returns B @ Q; it is populated by right_inverse().
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map the unconstrained tensor X to an orthogonal (unitary) matrix of self.shape."""
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n > k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        """Initialize the unconstrained parameter so that forward() reproduces Q."""
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
                             f"Got a tensor of shape {Q.shape}.")

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We always make sure to always copy Q in every path
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # If we do not have use_trivialization=True, we just implement the inverse of the forward
            # map for the Householder. To see why, think that for the Cayley map,
            # we would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor. It is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
                raise NotImplementedError("It is not possible to assign to the matrix exponential "
                                          "or the Cayley parametrizations when use_trivialization=False.")

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder and complex matrices.
+ # See note [Householder complex] + A, tau = torch.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1 + return A.mT if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device) + Q = torch.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(torch.zeros(m,n)) == torch.eye(m,n) + # Poor man's version of eye_like + neg_Id = torch.zeros_like(Q_init) + neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.) + return neg_Id + + +def orthogonal(module: Module, + name: str = 'weight', + orthogonal_map: Optional[str] = None, + *, + use_trivialization: bool = True) -> Module: + r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. 
math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~torch.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . 
+ + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. 
+ + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> # xdoctest: +IGNORE_WANT + >>> Q = orth_linear.weight + >>> torch.dist(Q.T @ Q, torch.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError("Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions.") + + if orthogonal_map is None: + orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder" + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". 
' + f'Got: {orthogonal_map}') + orth = _Orthogonal(weight, + orth_enum, + use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + +class _WeightNorm(Module): + def __init__( + self, + dim: Optional[int] = 0, + ) -> None: + super().__init__() + if dim is None: + dim = -1 + self.dim = dim + + def forward(self, weight_g, weight_v): + return torch._weight_norm(weight_v, weight_g, self.dim) + + def right_inverse(self, weight): + weight_g = torch.norm_except_dim(weight, 2, self.dim) + weight_v = weight + + return weight_g, weight_v + + +def weight_norm(module: Module, name: str = 'weight', dim: int = 0): + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` with two parameters: one specifying the magnitude + and one specifying the direction. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. 
+ + See https://arxiv.org/abs/1602.07868 + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _WeightNorm() + ) + ) + ) + >>> m.parametrizations.weight.original0.size() + torch.Size([40, 1]) + >>> m.parametrizations.weight.original1.size() + torch.Size([40, 20]) + + """ + _weight_norm = _WeightNorm(dim) + parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) + + def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + g_key = f"{prefix}{name}_g" + v_key = f"{prefix}{name}_v" + if g_key in state_dict and v_key in state_dict: + original0 = state_dict.pop(g_key) + original1 = state_dict.pop(v_key) + state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 + state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 + module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) + return module + + +class _SpectralNorm(Module): + def __init__( + self, + weight: torch.Tensor, + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12 + ) -> None: + super().__init__() + ndim = weight.ndim + if dim >= ndim or dim < -ndim: + raise IndexError("Dimension out of range (expected to be in range of " + f"[-{ndim}, {ndim - 1}] but got {dim})") + + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + f'got n_power_iterations={n_power_iterations}') + self.dim = dim if dim >= 0 else dim + ndim + self.eps = eps + if ndim > 1: + # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward) + self.n_power_iterations 
= n_power_iterations + weight_mat = self._reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + u = weight_mat.new_empty(h).normal_(0, 1) + v = weight_mat.new_empty(w).normal_(0, 1) + self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps)) + self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps)) + + # Start with u, v initialized to some reasonable values by performing a number + # of iterations of the power method + self._power_method(weight_mat, 15) + + def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + # Precondition + assert weight.ndim > 1 + + if self.dim != 0: + # permute dim to front + weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim)) + + return weight.flatten(1) + + @torch.autograd.no_grad() + def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None: + # See original note at torch/nn/utils/spectral_norm.py + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. 
+ # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + + # Precondition + assert weight_mat.ndim > 1 + + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type] + dim=0, eps=self.eps, out=self._u) # type: ignore[has-type] + self._v = F.normalize(torch.mv(weight_mat.H, self._u), + dim=0, eps=self.eps, out=self._v) # type: ignore[has-type] + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if weight.ndim == 1: + # Faster and more exact path, no need to approximate anything + return F.normalize(weight, dim=0, eps=self.eps) + else: + weight_mat = self._reshape_weight_to_matrix(weight) + if self.training: + self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.clone(memory_format=torch.contiguous_format) + v = self._v.clone(memory_format=torch.contiguous_format) + # The proper way of computing this should be through F.bilinear, but + # it seems to have some efficiency issues: + # https://github.com/pytorch/pytorch/issues/58093 + sigma = torch.vdot(u, torch.mv(weight_mat, v)) + return weight / sigma + + def right_inverse(self, value: torch.Tensor) -> 
torch.Tensor: + # we may want to assert here that the passed value already + # satisfies constraints + return value + + +def spectral_norm(module: Module, + name: str = 'weight', + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None) -> Module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + When applied on a vector, it simplifies to + + .. math:: + \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant + of the model. :math:`\sigma` is approximated performing one iteration of the + `power method`_ every time the weight is accessed. If the dimension of the + weight tensor is greater than 2, it is reshaped to 2D in power iteration + method to get spectral norm. + + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a + reimplementation of :func:`torch.nn.utils.spectral_norm`. + + .. note:: + When this constraint is registered, the singular vectors associated to the largest + singular value are estimated rather than sampled at random. These are then updated + performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor + is accessed with the module on `training` mode. + + .. 
note:: + If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, + is in training mode on removal, it will perform another power iteration. + If you'd like to avoid this iteration, set the module to eval mode + before its removal. + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter. Default: ``"weight"``. + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm. Default: ``1``. + eps (float, optional): epsilon for numerical stability in + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. + Default: ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with a new parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> snm = spectral_norm(nn.Linear(20, 40)) + >>> snm + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _SpectralNorm() + ) + ) + ) + >>> torch.linalg.matrix_norm(snm.weight, 2) + tensor(1.0081, grad_fn=) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + if dim is None: + if isinstance(module, (torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)) + return module diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py new file mode 100644 index 
0000000000000000000000000000000000000000..ae7ebcdf3df7f00cc9bde5b108b81c65eb0f884b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py @@ -0,0 +1,263 @@ +import contextlib +import warnings +from collections import defaultdict +from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + +__all__ = ["functional_call"] + + +def _untie_named_tensors_map( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """ + Unties all tied tensors in the module to parameters_and_buffers. + + This function returns a new untied_parameters_and_buffers dictionary and leave the original + untied_parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors + in the module to untied_parameters_and_buffers. The value of the new key is the user-given value + in the original parameters_and_buffers dictionary. + + If there are more than one user-given values for the same tied tensor, it will raise an error. + + For example, if the module has two tied weights self.foo and self.tied_foo and the user passes + {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the + user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the + user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error. + + Args: + module (torch.nn.Module): the module to determine which tensors are tied. + parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparamaterizing the module. + + Returns: + A new untied version of the parameters_and_buffers dictionary. + + Raises: + ValueError: if there are more than one user-given values for the same tied tensor. + """ + # A map of {name: tensor} for all tensors (including tied ones) in the module. 
+ all_named_tensors: Dict[str, Tensor] = {} + all_named_tensors.update(module.named_parameters(remove_duplicate=False)) + all_named_tensors.update(module.named_buffers(remove_duplicate=False)) + + # A map of {tensor: set(all_tied_names)} for all tensor names in the module. + tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set) + for name, tensor in all_named_tensors.items(): + tensor_to_tied_names_map[tensor].add(name) + + # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. + # If a name is not tied, it will not be in this map. + tied_names_map: Dict[str, Set[str]] = {} + for tied_names in tensor_to_tied_names_map.values(): + if len(tied_names) > 1: + for tied_name in tied_names: + tied_names_map[tied_name] = tied_names + + # Make sure the user didn't pass multiple values for the same tied tensor. + given_names = set(parameters_and_buffers.keys()) + given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys()) + for given_name in given_names_for_tied_tensors: + tied_names = tied_names_map[given_name] + if ( + # Detect if there are multiple keys present for the same tied tensor. + len(tied_names.intersection(given_names_for_tied_tensors)) > 1 + # Only raise an error if the user passed multiple values for the same tied tensor. + # If all given values are the same, don't raise. + and len({parameters_and_buffers[tied_name] for tied_name in tied_names}) + != 1 + ): + raise ValueError( + f"functional_call got multiple values for keys {sorted(tied_names)}, " + f"which are tied. 
Consider using tie_weights=False" + ) + + # Untie the given named tensor map + # Make a copy for not modifying the original dict + untied_parameters_and_buffers = parameters_and_buffers.copy() + for given_name in given_names_for_tied_tensors: + for tied_name in tied_names_map[given_name]: + untied_parameters_and_buffers[tied_name] = parameters_and_buffers[ + given_name + ] + return untied_parameters_and_buffers + + +@contextlib.contextmanager +def _reparametrize_module( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + *, + tie_weights: bool = False, + strict: bool = False, +) -> Iterator[None]: + if tie_weights: + untied_parameters_and_buffers = _untie_named_tensors_map( + module, parameters_and_buffers + ) + else: + untied_parameters_and_buffers = parameters_and_buffers + + accessor = NamedMemberAccessor(module) + if strict: + missing_keys, unexpected_keys = accessor.check_keys( + untied_parameters_and_buffers + ) + error_msgs = [] + if len(unexpected_keys) > 0: + error_msgs.append( + f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + ) + if len(missing_keys) > 0: + error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in reparametrizing for {}:\n\t{}".format( + module._get_name(), "\n\t".join(error_msgs) + ) + ) + + orig_parameters_and_buffers: Dict[str, Tensor] = {} + try: + orig_parameters_and_buffers, _ = accessor.swap_tensors_dict( + untied_parameters_and_buffers, allow_missing=True + ) + yield + finally: + new_parameters_and_buffers, _ = accessor.swap_tensors_dict( + orig_parameters_and_buffers, allow_missing=True + ) + # Sometimes the module is not completely stateless and has some in-place modifications on + # the _parameters and _buffers dictionaries. + # Write the changed parameters and buffers back to the original dict. 
+ parameters_and_buffers.update( + { + k: new_parameters_and_buffers[k] + for k in parameters_and_buffers + if k in new_parameters_and_buffers + } + ) + + +def functional_call( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. + + .. warning:: + + This API is deprecated as of PyTorch 2.0 and will be removed in a future + version of PyTorch. Please use :func:`torch.func.functional_call` instead, + which is a drop-in replacement for this API. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameters_and_buffers` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the `parameters_and_buffers` input. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) 
since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffers (dict of str and Tensor): the parameters that will be used in + the module call. + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparamaterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. + """ + warnings.warn( + "This API is deprecated as of PyTorch 2.0 and will be removed in a future " + "version of PyTorch. Please use torch.func.functional_call instead " + "which is a drop-in replacement for this API." 
+ ) + + return _functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +def _functional_call( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + # TODO allow kwargs such as unsafe and others for parametrization + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or isinstance( + module, + ( + torch.jit.RecursiveScriptModule, + torch.jit.ScriptModule, + torch.jit.ScriptFunction, + ), + ) + ): + raise RuntimeError("The stateless API can't be used with Jitted modules") + if isinstance(module, torch.nn.DataParallel): + raise RuntimeError( + "The stateless API can't be used with nn.DataParallel module" + ) + if kwargs is None: + kwargs = {} + if not isinstance(args, tuple): + args = (args,) + with _reparametrize_module( + module, parameters_and_buffers, tie_weights=tie_weights, strict=strict + ): + return module(*args, **kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..942a13a4eb83c4bac35f69f61bddf6ea6ca4645c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py @@ -0,0 +1,151 @@ +r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" +from torch.nn.parameter import Parameter, UninitializedParameter +from torch import _weight_norm, norm_except_dim +from typing import Any, TypeVar +import warnings +from ..modules import Module + +__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm'] + +class WeightNorm: + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = 
dim + + # TODO Make return type more specific + def compute_weight(self, module: Module) -> Any: + g = getattr(module, self.name + '_g') + v = getattr(module, self.name + '_v') + return _weight_norm(v, g, self.dim) + + @staticmethod + def apply(module, name: str, dim: int) -> 'WeightNorm': + warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.") + + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}") + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + if isinstance(weight, UninitializedParameter): + raise ValueError( + 'The module passed to `WeightNorm` can\'t have uninitialized parameters. ' + 'Make sure to run the dummy forward before applying weight normalization') + # remove w from parameter list + del module._parameters[name] + + # add g and v as new parameters and express w as g/||v|| * v + module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data)) + module.register_parameter(name + '_v', Parameter(weight.data)) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + module.register_forward_pre_hook(fn) + + return fn + + def remove(self, module: Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + del module._parameters[self.name + '_g'] + del module._parameters[self.name + '_v'] + setattr(module, self.name, Parameter(weight.data)) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, self.name, self.compute_weight(module)) + + +T_module = TypeVar('T_module', bound=Module) + +def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module: + r"""Apply weight normalization to a parameter in the given module. + + .. 
math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + .. warning:: + + This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` + which uses the modern parametrization API. The new ``weight_norm`` is compatible + with ``state_dict`` generated from old ``weight_norm``. + + Migration guide: + + * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. If this is bothering you, please comment on + https://github.com/pytorch/pytorch/issues/102999 + + * To remove the weight normalization reparametrization, use + :func:`torch.nn.utils.parametrize.remove_parametrizations`. + + * The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. To restore the old behavior, use + :func:`torch.nn.utils.parametrize.cached` before invoking the module + in question. 
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + + +def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module: + r"""Remove the weight normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError(f"weight_norm of '{name}' not found in {module}")