diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..100ece580da536c5b0c9d7e4b4b5cd4b101990b5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa6d65e774c9e1cadb558fdd1c656f98e5a6e7b1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d438095a9efe1cec7456cc1b02204a6b7fab314c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbe8d83c8cc50222021ced00344dd53628805347 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8f44805d47793dee629aee380371b94d9f0c10 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4c03572e3ab3c0c7ed9ff9f816ceac3b725051 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/case.py @@ -0,0 +1,188 @@ +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple, Union +from types import ModuleType + +import torch + +_TAGS: Dict[str, Dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. + """ + + SUPPORTED = 1 + NOT_SUPPORTED_YET = 0 + + +class ExportArgs: + __slots__ = ("args", "kwargs") + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +InputsType = Union[Tuple[Any, ...], ExportArgs] + + +def check_inputs_type(x): + if not isinstance(x, (ExportArgs, tuple)): + raise ValueError( + f"Expecting inputs type to be either a tuple, or ExportArgs, got: {type(x)}" + ) + + +def _validate_tag(tag: str): + parts = tag.split(".") + t = _TAGS + for part in parts: + assert set(part) <= set( + string.ascii_lowercase + "-" + ), f"Tag contains invalid characters: {part}" + if part in t: + t = t[part] + else: + raise ValueError(f"Tag {tag} is not found in registered tags.") + + +@dataclass(frozen=True) +class ExportCase: + example_inputs: InputsType + description: str # A description of the use case. + model: torch.nn.Module + name: str + extra_inputs: Optional[InputsType] = None # For testing graph generalization. + # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) + tags: Set[str] = field(default_factory=set) + support_level: SupportLevel = SupportLevel.SUPPORTED + dynamic_shapes: Optional[Dict[str, Any]] = None + + def __post_init__(self): + check_inputs_type(self.example_inputs) + if self.extra_inputs is not None: + check_inputs_type(self.extra_inputs) + + for tag in self.tags: + _validate_tag(tag) + + if not isinstance(self.description, str) or len(self.description) == 0: + raise ValueError(f'Invalid description: "{self.description}"') + + +_EXAMPLE_CASES: Dict[str, ExportCase] = {} +_MODULES: Set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {} + + +def register_db_case(case: ExportCase) -> None: + """ + Registers a user provided ExportCase into example bank. + """ + if case.name in _EXAMPLE_CASES: + if case.name not in _EXAMPLE_CONFLICT_CASES: + _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]] + _EXAMPLE_CONFLICT_CASES[case.name].append(case) + return + + _EXAMPLE_CASES[case.name] = case + + +def to_snake_case(name): + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +def _make_export_case(m, name, configs): + if not issubclass(m, torch.nn.Module): + raise TypeError("Export case class should be a torch.nn.Module.") + m = m() + + if "description" not in configs: + # Fallback to docstring if description is missing. + assert ( + m.__doc__ is not None + ), f"Could not find description or docstring for export case: {m}" + configs = {**configs, "description": m.__doc__} + return ExportCase(**{**configs, "model": m, "name": name}) + + +def export_case(**kwargs): + """ + Decorator for registering a user provided case into example bank. + """ + + def wrapper(m): + configs = kwargs + module = inspect.getmodule(m) + if module in _MODULES: + raise RuntimeError("export_case should only be used once per example file.") + + assert module is not None + _MODULES.add(module) + normalized_name = to_snake_case(m.__name__) + module_name = module.__name__.split(".")[-1] + if module_name != normalized_name: + raise RuntimeError( + f'Module name "{module.__name__}" is inconsistent with exported program ' + + f'name "{m.__name__}". Please rename the module to "{normalized_name}".' + ) + + case = _make_export_case(m, module_name, configs) + register_db_case(case) + return case + + return wrapper + + +def export_rewrite_case(**kwargs): + def wrapper(m): + configs = kwargs + + parent = configs.pop("parent") + assert isinstance(parent, ExportCase) + key = parent.name + if key not in _EXAMPLE_REWRITE_CASES: + _EXAMPLE_REWRITE_CASES[key] = [] + + configs["example_inputs"] = parent.example_inputs + case = _make_export_case(m, to_snake_case(m.__name__), configs) + _EXAMPLE_REWRITE_CASES[key].append(case) + return case + + return wrapper + + +def normalize_inputs(x: InputsType) -> ExportArgs: + if isinstance(x, tuple): + return ExportArgs(*x) + + assert isinstance(x, ExportArgs) + return x diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..423d0a899751eedc255d5760df4b7fc92303b7b4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a220802c2339c8bf3c6ed3b5a3978a09f260dee2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c94808f70a100ae39d68f27b5a1d3d9fbfbca600 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adbf8cd2199b881e700a791532fe9acdb7d396ec Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..925f950367e6667651495d1e575a64db30d6b8d3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434db9913edb453e3f6baef85459feb9c1163553 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1785cbb7c9d45bc42f45f801cff5020ae42de7ab Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af5ff3c64d325a235cfd368bf13b565c98885000 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cdc9876e704553139deb7e6d81cfb5c1bd39893 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c9586330073ae68f6064caff998b5f1315d38df Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d4c7d1dc170b8dcad107e5534f05c4e809e1f68 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0ca4e1d4e16b4e6fefb5836414be640887de67c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7796714fbfc81985d251a5f01dfca7316075d9c6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e2cdc08a9d71090aca5b6d3cd4df96aaa6afa2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10cf9cda1699facb959f660f401663a1a0e047ee Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..097ceec64c1a430e950cd0abfd4dd0bcde6b6c52 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af0995a470cdd7155321e7bbca051dc529508ad Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adfa19b10db0dabfc8c01e92ef43d7650b50687c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8ae78a41c70c9581656c6145a6deaa36fac51a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8aeadc45ae291f363bb4850b30bab4fb14214d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,26 @@ +import torch + +from torch._export.db.case import export_case + + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + 1 + + +@export_case( + example_inputs=(torch.randn(3, 2),), +) +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend to + use `allow_in_graph` to mitigate this problem. + """ + + def forward(self, x): + return MyAutogradFunction.apply(x) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..68dd3772684d1c8ea784a5d74214895dedeeb530 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,46 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + + +@export_case( + example_inputs=(torch.ones(3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self): + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..38905b57e31243e10e52193ab36a8503ba4991f4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,63 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(6),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. + ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..39eff84af34812e1a31006c698652ec6dc2bbd20 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,26 @@ +import functools + +import torch + +from torch._export.db.case import export_case + + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(3, 2)), +) +class Decorator(torch.nn.Module): + """ + Decorators calls are inlined into the exported function during tracing. + """ + + @test_decorator + def forward(self, x, y): + return x + y diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..ec95df0bd97dda4e673e7898a1072db8215f8310 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.assert"}, +) +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..51b8dd57252529079411cf2db2b4a14a4b905634 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,19 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ones(x.shape[0] * 2) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..5be0003fd170abb49afc80544229177d4b8b8de4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import map + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(2)), + tags={"torch.dynamic-shape", "torch.map"}, +) +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def __init__(self): + super().__init__() + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6a50320f5baba5843e6e4831789c1993b5e6ed --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,24 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel +from torch.export import Dim + +x = torch.ones(3, 2) +dim0_x = Dim("dim0_x") + +@export_case( + example_inputs=(x,), + tags={"torch.dynamic-shape", "python.builtin"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, + dynamic_shapes={"x": {0: dim0_x}}, +) +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x[: round(x.shape[0] / 2)] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..6182a747955561fc8bba1a4e3c3e6187e987c135 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,32 @@ +import torch + +from torch._export.db.case import export_case, ExportArgs, SupportLevel + + +@export_case( + example_inputs=ExportArgs( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)], + mykw0=torch.randn(4), + input0=torch.randn(4), input1=torch.randn(4) + ), + tags={"python.data-structure"}, + support_level=SupportLevel.SUPPORTED, +) +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + def __init__(self): + super().__init__() + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..58b946f94a0c28447501a1d1a1fd4c98405d49d2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,27 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(2)), + tags={"python.closure"}, +) +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + def __init__(self): + super().__init__() + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1689537db833a90bf09122221dde47aad79ebf34 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,26 @@ +import contextlib + +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.context-manager"}, +) +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + """ + Null context manager in Python will be traced out. + """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..4a06207b6eaf8f24d673c7ec227c3a5643c2d6a3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,19 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.randn(2, 3),), + tags={"python.object-model"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.ones(2, 3)): + if y is not None: + return x + y + return x diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py new file mode 100644 index 0000000000000000000000000000000000000000..b9f4dd8f8496ccfd6c81b7007a96d9a05e6ffce5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/torch_sym_min.py @@ -0,0 +1,17 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.operator"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. + """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb16cd64a56fce4c4ccfdbb257f32f11514439c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/user_input_mutation.py @@ -0,0 +1,18 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.mutation"}, + support_level=SupportLevel.SUPPORTED, +) +class UserInputMutation(torch.nn.Module): + """ + Directly mutate user input in forward + """ + + def forward(self, x): + x.mul_(2) + return x.cos() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..fc412b8c5082dd8c4346711314fc7cc43c1a9ba2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/logging.py @@ -0,0 +1,2 @@ +def exportdb_error_message(case_name: str): + return "" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479802d2843b71d33ba3991f2b628360d64ec93c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5742a264b8a0b311a83e4f13a2cf26699f60a8fc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa9b8093c370dd565dfb7fb44e4b22474446af0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/node_metadata.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Set + + +NodeMetadataValue = Any + + +PROTECTED_KEYS: Set[str] = { + "val", + "stack_trace", + "nn_module_stack", + "debug_handle", + "tensor_meta", +} + + +class NodeMetadata: + def __init__(self, data: Dict[str, Any]) -> None: + self.data: Dict[str, Any] = data.copy() + + def __getitem__(self, key: str) -> NodeMetadataValue: + return self.data[key] + + def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: + if key in PROTECTED_KEYS: + raise RuntimeError(f"Could not override node key: {key}") + self.data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.data + + def copy(self) -> "NodeMetadata": + return NodeMetadata(self.data.copy()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ce2ac03c23600c86ff02e38a2a4bfeefef9e2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3d18ad5c640b526e102a8ec3152c51481691ddb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a990b435b3b09e5289141a6646590bfbd80d4b63 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0df575a539d8d17a9c4783c9aa90c4041390891 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e963a7ac8bac8414ecdb7936f7b70fcc75e9af Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2b9c674859f4eefd56033cf37536a1b532ae65 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,66 @@ +import operator + +import torch + +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. + """ + + def __init__(self, specs, sig) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm): + def get_arg_spec(arg): + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." + ) + else: + return ConstantArgument(value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + for i, arg in enumerate(node.args): + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + self.specs[node.kwargs["path"]].inputs.append( + get_arg_spec(arg) + ) + elif kind == "module_call_outputs": + self.specs[node.kwargs["path"]].outputs.append( + get_arg_spec(arg) + ) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target == operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5f93eabdc2b5d8cea145ebc8399ccc3e2c5a7816 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py @@ -0,0 +1,248 @@ +import collections +from typing import Any, Dict, Union + +import torch +from torch._export.verifier import SpecViolationError +from torch._guards import detect_fake_mode +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors and + ScriptObjects) as keys. We store tensors normally, but ScriptObjects are + stored by hash, because different torch.ScriptObjects can point to the same + underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self): + # Underlying dict that we use to implement this mapping. + self._constant_attrs: Dict[Union[int, torch.Tensor], Any] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. + self._script_object_map: Dict[int, torch.ScriptObject] = {} + + def __getitem__(self, key: Union[torch.Tensor, torch.ScriptObject]) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor)) + return self._constant_attrs[real_key] + + def __setitem__( + self, key: Union[torch.Tensor, torch.ScriptObject], value: Any + ) -> None: + if isinstance(key, torch.ScriptObject): + self._constant_attrs[hash(key)] = value + self._script_object_map[hash(key)] = key + elif isinstance(key, torch.Tensor): + self._constant_attrs[key] = value + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. + parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] + if len(parent_fqn) > 0: + return f"{parent_fqn}.{constant_name}" + else: + return constant_name + + +def lift_constants_pass( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]: + """ + Takes a graph module, graph signature, and modifies them implace to lift any + constants (tensors or custom classes) as inputs to the graph. Returns a + dictionary of names to constants. + + Arguments: + gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. + graph_signature (ExportGraphSignature): This graph signature will be + mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. + constant_attrs (ConstantAttr): A mapping from a constant value to its + fully-qualified path in `gm`. This is used to maintain consistent + location of constants between the original module and the exported + version. + + Returns: + A dictionary of fqn => constant value. + """ + all_constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] = {} + + inputs = graph_signature.input_specs + num_custom_obj = sum( + input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs + ) + num_tensor_constants = sum( + input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, None + for node in gm.graph.nodes: + if node.op == "placeholder" and node.name in graph_signature.user_inputs: + first_user_input = node + break + first_user_input_loc += 1 + + lifted_objs = ConstantAttrMap() + for node in gm.graph.nodes: + if node.op == "get_attr": + constant_val = getattr(gm, node.target) + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. + const_placeholder_node = lifted_objs[constant_val] + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + continue + + # For ScriptObject and Tensor constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. + if isinstance(constant_val, torch.ScriptObject): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = constant_attrs.get(constant_val) + if constant_fqn is not None: + _, _, constant_name = constant_fqn.rpartition(".") + else: + constant_name = f"_lifted_custom_obj{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = constant_attrs.get(constant_val) + if constant_fqn is not None: + _, _, constant_name = constant_fqn.rpartition(".") + else: + constant_name = f"_lifted_tensor_constant{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + elif isinstance(constant_val, torch.fx.GraphModule): + continue + elif "LoweredBackendModule" in type(constant_val).__name__: + continue + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs[constant_val] = const_placeholder_node + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> Dict[str, Union[torch.Tensor, torch.ScriptObject]]: + """When tracing, we produce a graph with an actual ScriptObject in the + meta["val"]. Eventually we want to change this behavior, when FakeMode infra + for ScriptObjects lands. + + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for node in gm.graph.nodes: + if "val" not in node.meta or not isinstance( + node.meta["val"], torch.ScriptObject + ): + continue + + old_meta = node.meta["val"] + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa6b700d2df798d7584568995b6c884d7219f058 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..0d6bebb71f3f854c9571e6f7507b5a715400f3c6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.py @@ -0,0 +1,346 @@ +# NOTE: This is a placeholder for iterating on export serialization schema design. +# Anything is subject to change and no guarantee is provided at this point. + +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Dict, List, Optional, Tuple + +from torch._export.serde.union import _Union + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (5, 1) +TREESPEC_VERSION = 1 + + +class ScalarType(IntEnum): + UNKNOWN = 0 + BYTE = 1 + CHAR = 2 + SHORT = 3 + INT = 4 + LONG = 5 + HALF = 6 + FLOAT = 7 + DOUBLE = 8 + COMPLEXHALF = 9 + COMPLEXFLOAT = 10 + COMPLEXDOUBLE = 11 + BOOL = 12 + BFLOAT16 = 13 + + +class Layout(IntEnum): + Unknown = 0 + SparseCoo = 1 + SparseCsr = 2 + SparseCsc = 3 + SparseBsr = 4 + SparseBsc = 5 + _mkldnn = 6 + Strided = 7 + + +class MemoryFormat(IntEnum): + Unknown = 0 + ContiguousFormat = 1 + ChannelsLast = 2 + ChannelsLast3d = 3 + PreserveFormat = 4 + + +@dataclass +class Device: + type: str + index: Optional[int] = None + + +@dataclass(repr=False) +class SymExprHint(_Union): + as_int: int + as_float: float + as_bool: bool + + +# This is for storing the symbolic expressions behind symints/symfloats/symbools +# For example, we can get something like +# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) +# if we also have the hint that s0 and s1 are both 2. +@dataclass +class SymExpr: + expr_str: str + hint: Optional[SymExprHint] = None + + +@dataclass(repr=False) +class SymInt(_Union): + as_expr: SymExpr + as_int: int + + +@dataclass(repr=False) +class SymBool(_Union): + as_expr: SymExpr + as_bool: bool + + +@dataclass +class TensorMeta: + dtype: ScalarType + sizes: List[SymInt] + requires_grad: bool + device: Device + strides: List[SymInt] + storage_offset: SymInt + layout: Layout + + +# In most cases we will use the "as_name" field to store arguments which are +# SymInts. +# The "as_int" field is used in the case where we have a list containing a mix +# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to +# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints +# to the "as_int" field. +@dataclass(repr=False) +class SymIntArgument(_Union): + as_name: str + as_int: int + + +# In most cases we will use the "as_name" field to store arguments which are +# SymBools. +# The "as_bool" field is used in the case where we have a list containing a mix +# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to +# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools +# to the "as_bool" field. +@dataclass(repr=False) +class SymBoolArgument(_Union): + as_name: str + as_bool: bool + + +@dataclass +class TensorArgument: + name: str + + +# This is use for storing the contents of a list which contain optional tensors +# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the +# type List[OptionalTensorArgument], with tensor values seiralized to the +# "as_tensor" field, and None values serialized to the "as_none" field. +@dataclass(repr=False) +class OptionalTensorArgument(_Union): + as_tensor: str + as_none: Tuple[()] + + +@dataclass +class GraphArgument: + name: str + graph: 'Graph' + + +@dataclass +class CustomObjArgument: + name: str + class_fqn: str + + +# This is actually a union type +@dataclass(repr=False) +class Argument(_Union): + as_none: Tuple[()] + as_tensor: TensorArgument + as_tensors: List[TensorArgument] + as_int: int + as_ints: List[int] + as_float: float + as_floats: List[float] + as_string: str + as_strings: List[str] + as_sym_int: SymIntArgument + as_sym_ints: List[SymIntArgument] + as_scalar_type: ScalarType + as_memory_format: MemoryFormat + as_layout: Layout + as_device: Device + as_bool: bool + as_bools: List[bool] + as_sym_bool: SymBoolArgument + as_sym_bools: List[SymBoolArgument] + as_graph: GraphArgument + as_optional_tensors: List[OptionalTensorArgument] + as_custom_obj: CustomObjArgument + as_operator: str + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: str + arg: Argument + + +@dataclass +class Node: + target: str + inputs: List[NamedArgument] + outputs: List[Argument] + metadata: Dict[str, str] + + +@dataclass +class Graph: + inputs: List[Argument] + outputs: List[Argument] + nodes: List[Node] + tensor_values: Dict[str, TensorMeta] + sym_int_values: Dict[str, SymInt] + sym_bool_values: Dict[str, SymBool] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. + is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Argument + + +@dataclass +class InputToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class InputToBufferSpec: + arg: TensorArgument + buffer_name: str + persistent: bool + + + +@dataclass +class InputToTensorConstantSpec: + arg: TensorArgument + tensor_constant_name: str + + +@dataclass +class InputToCustomObjSpec: + arg: CustomObjArgument + custom_obj_name: str + + +@dataclass(repr=False) +class InputSpec(_Union): + user_input: UserInputSpec + parameter: InputToParameterSpec + buffer: InputToBufferSpec + tensor_constant: InputToTensorConstantSpec + custom_obj: InputToCustomObjSpec + + +@dataclass +class UserOutputSpec: + arg: Argument + + +@dataclass +class LossOutputSpec: + arg: TensorArgument + + +@dataclass +class BufferMutationSpec: + arg: TensorArgument + buffer_name: str + + +@dataclass +class GradientToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class GradientToUserInputSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass +class UserInputMutationSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass(repr=False) +class OutputSpec(_Union): + user_output: UserOutputSpec + loss_output: LossOutputSpec + buffer_mutation: BufferMutationSpec + gradient_to_parameter: GradientToParameterSpec + gradient_to_user_input: GradientToUserInputSpec + user_input_mutation: UserInputMutationSpec + + +@dataclass +class GraphSignature: + input_specs: List[InputSpec] + output_specs: List[OutputSpec] + + +@dataclass +class RangeConstraint: + min_val: int + max_val: int + + +@dataclass +class ModuleCallSignature: + inputs: List[Argument] + outputs: List[Argument] + + # These are serialized by calling pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: str + out_spec: str + + +@dataclass +class ModuleCallEntry: + fqn: str + signature: Optional[ModuleCallSignature] = None + + +@dataclass +class GraphModule: + graph: Graph + signature: GraphSignature + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: List[ModuleCallEntry] + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be upadted. +@dataclass +class SchemaVersion: + major: int # Major version number is bumped every time a breaking change is made. + minor: int # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: GraphModule + # Key is the opset namespace (ex. aten), and value is the version number + opset_version: Dict[str, int] + range_constraints: Dict[str, RangeConstraint] + schema_version: SchemaVersion + dialect: str diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..cde4cf1ada271ca19800f2480a9f8c203286a340 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,285 @@ +import dataclasses +import hashlib +import re +import typing +from enum import IntEnum +from typing import Any, Dict, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class SchemaUpdateError(Exception): + pass + + +def _check(x, msg): + if not x: + raise SchemaUpdateError(msg) + + +def _staged_schema(): + ret: Dict[str, Any] = {} + defs = {} + + def _handle_aggregate(ty): + def dump_type(t): + if isinstance(t, type): + return t.__name__ + elif isinstance(t, str): + assert t in defs + return t + elif o := typing.get_origin(t): + # Lemme know if there's a better way to do this. + if o == list: + head = "List" + elif o == dict: + head = "Dict" + elif o == tuple: + if typing.get_args(t) == (): + return "Tuple[()]" + head = "Tuple" + elif o == Union: + args = typing.get_args(t) + assert len(args) == 2 and args[1] == type(None) + return f"Optional[{dump_type(args[0])}]" + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + return ( + f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]" + ) + elif t == (): + return "()" + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + + def dump_field(f): + t = dump_type(f.type) + ret = {"type": t} + + value = dataclasses.MISSING + if f.default is not dataclasses.MISSING: + value = f.default + elif f.default_factory is not dataclasses.MISSING: + value = f.default_factory() + + if t.startswith("Optional[") and value is not None: + raise AssertionError( + f"Optional field {ty.__name__}.{f.name} must have default value to be None." + ) + + if value is not dataclasses.MISSING: + default = str(value) + ret["default"] = default + return ret + + return {f.name: dump_field(f) for f in dataclasses.fields(ty)} + + def _handle_int_enum(name, ty): + ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + + def _handle_struct(name, ty): + ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)} + + def _handle_union(name, ty): + ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)} + + for name in dir(schema): + if name.startswith("_"): + continue + + value = getattr(schema, name) + + if hasattr(value, "__module__") and value.__module__ != schema.__name__: + continue + + defs[name] = value + + for name, value in defs.items(): + if isinstance(value, type): + if issubclass(value, IntEnum): + _handle_int_enum(name, value) + elif dataclasses.is_dataclass(value): + if issubclass(value, _Union): + _handle_union(name, value) + else: + _handle_struct(name, value) + else: + raise AssertionError(f"Unknown schema type {name}: {value}") + elif isinstance(value, (int, tuple)): + assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") + else: + raise AssertionError(f"Unknown variable {name}: {value}") + + ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in ret["SCHEMA_VERSION"]) + ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert ret["TREESPEC_VERSION"] > 0 + return ret + + +def _diff_schema(dst, src): + additions = {key: src[key] for key in src.keys() - dst.keys()} + subtractions = {key: dst[key] for key in dst.keys() - src.keys()} + + common_keys = src.keys() & dst.keys() + + versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} + common_keys -= versions + + for key in common_keys: + src_kind = src[key]["kind"] + src_fields = src[key]["fields"] + dst_kind = dst[key]["kind"] + dst_fields = dst[key]["fields"] + _check( + src_kind == dst_kind, + f"Type {key} changed kind from {dst_kind} to {src_kind}", + ) + assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) + added_fields = { + key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() + } + subtracted_fields = { + key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() + } + common_fields = src_fields.keys() & dst_fields.keys() + + for field in common_fields: + src_field = src_fields[field] + dst_field = dst_fields[field] + if src_kind == "struct": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + if "default" in src_field and "default" not in dst_field: + added_fields[field] = {} + added_fields[field]["default"] = src_field["default"] + if "default" not in src_field and "default" in dst_field: + subtracted_fields[field] = {} + subtracted_fields[field]["default"] = dst_field["default"] + elif src_kind == "enum": + _check( + src_field == dst_field, + f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", + ) + elif src_kind == "union": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + else: + raise AssertionError(f"Unknown kind {src_kind}: {key}") + if len(added_fields) > 0: + assert key not in additions + additions[key] = {} + additions[key]["fields"] = added_fields + if len(subtracted_fields) > 0: + assert key not in subtractions + subtractions[key] = {} + subtractions[key]["fields"] = subtracted_fields + + return additions, subtractions + + +def _hash_schema(s): + return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() + + +@dataclasses.dataclass +class _Commit: + result: Dict[str, Any] + checksum_result: str + path: str + additions: Dict[str, Any] + subtractions: Dict[str, Any] + base: Dict[str, Any] + checksum_base: Optional[str] + + +def update_schema(): + import importlib.resources + + if importlib.resources.is_resource(__package__, "schema.yaml"): + content = importlib.resources.read_text(__package__, "schema.yaml") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) + _check(match is not None, "checksum not found in schema.yaml") + assert match is not None + checksum_base = match.group(1) + from yaml import load, Loader + + dst = load(content, Loader=Loader) + assert isinstance(dst, dict) + else: + checksum_base = None + dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} + + src = _staged_schema() + additions, subtractions = _diff_schema(dst, src) + return _Commit( + result=src, + checksum_result=_hash_schema(src), + path=__package__.replace(".", "/") + "/schema.yaml", + additions=additions, + subtractions=subtractions, + base=dst, + checksum_base=checksum_base, + ) + + +def check(commit: _Commit, force_unsafe: bool = False): + next_version = None + reason = "" + # Step 1: Detect major schema updates. + if len(commit.additions) > 0: + for k, v in commit.additions.items(): + if k not in commit.base: + continue + kind = commit.result[k]["kind"] + fields = v["fields"] + for f, d in fields.items(): + if "default" not in d and kind == "struct": + reason += ( + f"Field {k}.{f} is added to schema.py without a default value as an incomparible change " + + "which requires major version bump.\n" + ) + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + if k not in commit.result: + continue + for f in v["fields"]: + reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if force_unsafe: + reason += "--force-unsafe is used." + next_version = commit.result["SCHEMA_VERSION"] + else: + # Step 2: Detect minor schema updates. + if next_version is None and len(commit.additions) > 0: + for k, v in commit.additions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is added to schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + if next_version is None and len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is removed from schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + + return next_version, reason diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..01625ec63c327df1f0986680d2d5fe349f211b0d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py @@ -0,0 +1,2434 @@ +import base64 +import copy +import dataclasses +import heapq +import inspect +import io +import json +import logging +import math +import operator +import typing +import copyreg + +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Any, + Callable, + cast, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Union, +) + +import sympy + +import torch +import torch.export.exported_program as ep +from torch._export.serde.schema import SchemaVersion +from torch._export.verifier import load_verifier +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental import symbolic_shapes +from torch.utils import _pytree as pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.value_ranges import ValueRanges + +from .schema import ( # type: ignore[attr-defined] + Argument, + BufferMutationSpec, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToCustomObjSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + Node, + OptionalTensorArgument, + OutputSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union + + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +from .upgrade import GraphModuleOpUpgrader + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: Dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument] + + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16 +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + + +_SYM_INT_OPS = { + operator.mul, + operator.add, + operator.sub, + operator.floordiv, + operator.mod, + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, +} + + +_SYM_BOOL_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + torch.sym_not, +} + + +@dataclass +class SerializedArtifact: + exported_program: Union[ExportedProgram, bytes] + state_dict: bytes + constants: bytes + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, torch.SymInt) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(str(s))) + else: + return SymInt.create(as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint))) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=str(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. + """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(0), # TODO needs to be fixed. + layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps(_dataclass_to_dict(tensor_meta), cls=EnumEncoder).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor(serialized_tensor_meta: bytes, is_parameter: bool) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert _CURRENT_DESERIALIZER is not None, "Need access to current deserializer state" + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + return fake_tensor + + +def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes: + assert FakeTensor not in copyreg.dispatch_table, "Refusing to stomp on existing FakeTensor reducer" + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. + torch.save(artifact, buffer) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact(serialized: bytes): + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + artifact = torch.load(buffer) + assert isinstance(artifact, dict) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError( + "Export constraints cannot be non-integer expressions" + ) + + +def _int_to_sympy_int(val) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val == math.inf: + return sympy.oo + if val == -math.inf: + return -sympy.oo + return sympy.Integer(val) + + +def serialize_range_constraints( + range_constraints: Dict[sympy.Symbol, ValueRanges] +) -> Dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower), # type: ignore[arg-type] + _sympy_int_to_int(v.upper), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool: + returns = target._schema.returns + return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType) + + +def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool: + returns = target._schema.returns + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + +@dataclass +class GraphState: + inputs: List[Argument] = field(default_factory=list) + outputs: List[Argument] = field(default_factory=list) + nodes: List[Node] = field(default_factory=list) + tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: Dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) + is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +class GraphModuleSerializer: + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: List[ep.ModuleCallEntry] + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: Dict[str, torch._C.ScriptObject] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + if isinstance(node.meta['val'], torch.Tensor): + graph_input = Argument.create(as_tensor=TensorArgument(name=node.name)) + self.graph_state.tensor_values[node.name] = serialize_tensor_meta(node.meta["val"]) + elif isinstance(node.meta['val'], torch.SymInt): + raise AssertionError("SymInt graph input is not implemented yet.") + elif isinstance(node.meta['val'], (int, bool, str, float, type(None))): + graph_input = self.serialize_input(node.meta['val']) + elif isinstance(node.meta['val'], ep.CustomObjArgument): + class_fqn = node.meta["val"].class_fqn + graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)) + self.graph_state.custom_obj_values[node.name] = self.serialize_script_obj_meta(node.meta["val"]) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. + return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + if node.target in _SYM_INT_OPS: + assert len(node.kwargs) == 0 + meta_val = node.meta["val"] + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))], + metadata=self.serialize_metadata(node), + ) + elif node.target in _SYM_BOOL_OPS: + assert len(node.kwargs) == 0 + meta_val = node.meta["val"] + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + pass + + def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: + ret = {} + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + + # node.meta["nn_module_stack"] could have two forms: + # 1. (path: str, module_type: 'type'), e.g. + # ('', ) + # 2. (path: str, module_type: str), e.g. + # ('', 'sigmoid.inference.MySimpleModel') + # ExportedProgram directly produced by torch.export() has form 1 + # ExportedProgram deserialized from disk has form 2 + # TODO: This is not ideal, we should fix this. + if isinstance(ty, str): + normalized_ty = ty + else: + normalized_ty = ty.__module__ + "." + ty.__qualname__ + + return path + "," + normalized_ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" + for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" for source_fn in source_fn_st] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + return ret + + def serialize_script_obj_meta(self, script_obj_meta: ep.CustomObjArgument) -> CustomObjArgument: + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: + serialized_args = [] + args_names = inspect.signature(op).parameters.keys() + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument(name=args_name, arg=self.serialize_input(arg)) + ) + return serialized_args + + def serialize_inputs( + self, target: torch._ops.OpOverload, args, kwargs=None + ) -> List[NamedArgument]: + assert isinstance(target, torch._ops.OpOverload) + kwargs = kwargs or {} + serialized_args = [] + for i, schema_arg in enumerate(target._schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(kwargs[schema_arg.name]), + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i]), + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. + """ + inputs = [ + NamedArgument( + name="", + arg=self.serialize_input(a), + ) for a in args + ] + inputs.extend([ + NamedArgument( + name=name, + arg=self.serialize_input(a) + ) for name, a in kwargs.items() + ]) + return inputs + + def is_sym_int_arg(self, arg) -> bool: + return isinstance(arg, int) or ( + isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_bool_values + ) + + def serialize_input(self, arg) -> Argument: + import torch._inductor.ir as inductor_ir + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError("getattr nodes containing tensors should not appear in the graph") + elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create(as_graph=GraphArgument(name=arg.target, graph=graph)) + else: + raise SerializeError(f"Unsupported getattr attribute {arg.target} with type: {type(attr)}") + elif self.is_sym_int_arg(arg): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=arg.name)) + elif self.is_sym_bool_arg(arg): + return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name)) + else: + if isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn)) + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, bool): + return Argument.create(as_bool=arg) + elif isinstance(arg, str): + return Argument.create(as_string=arg) + elif isinstance(arg, int): + return Argument.create(as_int=arg) + elif isinstance(arg, float): + return Argument.create(as_float=arg) + elif arg is None: + return Argument.create(as_none=()) + elif isinstance(arg, (list, tuple)): + # Must check bool first, as bool is also treated as int + if all(isinstance(a, bool) for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(isinstance(a, int) for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(isinstance(a, float) for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(isinstance(a, str) for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(isinstance(a, torch.SymInt) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create( + as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif isinstance(a, int): + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError("getattr nodes containing tensors should not appear in the graph") + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, torch.fx.Node): + return OptionalTensorArgument.create(as_tensor=a.name) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all(isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create(as_tensor=a.get_name()) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument type: {[type(a) for a in arg]}") + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create(as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") and # type: ignore[attr-defined] + arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. Please define " + "serialization methods via def_pickle()." + ) + # Custom objects through torchind are serializable with pickle, + # through implementing the .def_pickle function. This should result + # in the object containing a __getstate__ and __setstate__ + # serialize/deserialize function. + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)) + elif isinstance(arg, torch._ops.OpOverload): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError(f"Unsupported argument type: {type(arg)}") + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + if spec.kind == ep.InputKind.USER_INPUT: + return InputSpec.create( + user_input=UserInputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn), + custom_obj_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec( + arg=TensorArgument(name=spec.arg.name) + ) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, ep.TensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, ep.SymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, ep.ConstantArgument): + return self.serialize_input(x.value) + elif isinstance(x, ep.CustomObjArgument): + return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)) + else: + raise AssertionError("TODO") + + def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature: + return ModuleCallSignature( + inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs], + in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION), + out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION), + ) + + def serialize_module_call_graph(self, module_call_graph: List[ep.ModuleCallEntry]) -> List[ModuleCallEntry]: + return [ + ModuleCallEntry( + fqn=entry.fqn, + signature=self.serialize_module_call_signature(entry.signature) if entry.signature else None, + ) for entry in module_call_graph + ] + + def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: + """For a given node, return the dataclass representing its output values. + + [NOTE: Multiple outputs] We handle aggregates differently than FX. For + FX, it looks like: + + x = call_function("multiple_return", ...) + element0 = call_function(getitem, x, 0) + foo = call_function("use_output", element0) + + We do not want the intermediate `getitem` call, so our serialized thing looks like: + + element0, element1, element2 = call_function("multiple_return", ...) + foo = call_function("use_output", element0) + + We want names to be consistent across these two schemes, so that we can + mostly reuse the names coming from FX. This function computes a mapping from + the FX representation to our representation, preserving the names. + """ + assert node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + + assert isinstance(node.target, torch._ops.OpOverload) + returns = node.target._schema.returns + + if len(returns) == 0: + return [] + + meta_val = node.meta["val"] + + def output_node_at_index(node, index): + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + return user + return None + + # Check single value return + if _is_single_tensor_list_return(node.target): + # e.g "-> Tensor[]" + tensor_args = [] + for idx, meta in enumerate(meta_val): + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + elif len(returns) == 1: + return [self.serialize_output(node.name, meta_val)] + + # There are a two possibilities at this point: + # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" + # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" + # + # Either way, start by gathering a list of TensorArguments with the correct names. + # For consistent naming with FX, consult the downstream `getitem` node and + # make sure our outputs have the same name. + + output_arguments = [] + for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): + if meta is None: + assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType)) + # When the return type is annoated as Tensor type, the op can also return an + # undefined Tensor which will be implicitly converted to None in Python. + output_arguments.append(Argument.create(as_none=())) + elif isinstance(meta, FakeTensor): + assert isinstance(return_schema.real_type, torch.TensorType) + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + output_arguments.append(self.serialize_output(name, meta)) + elif isinstance(meta, list): + # for List[Tensor] return type + assert isinstance( + return_schema.real_type, torch.ListType + ) and isinstance( + return_schema.real_type.getElementType(), torch.TensorType + ) + user_node = output_node_at_index(node, idx) + assert user_node is not None + + args = [] + for i, m in enumerate(meta): + if m is None: + continue + sub_user_node = output_node_at_index(user_node, i) + assert sub_user_node is not None, f"No user found at index {i}" + + args.append(self.serialize_tensor_output(sub_user_node.name, m)) + output_arguments.append(Argument.create(as_tensors=args)) + elif isinstance(meta, (int, SymInt)): + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + output_arguments.append(self.serialize_output(name, meta)) + else: + raise ValueError(f"Unhandled output type {type(meta)} from node {node.format_node()}") + + return output_arguments + + def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: + """ + For serializing HOO outputs since HOOs do not have a schema. + """ + meta_val = node.meta["val"] + + if isinstance(meta_val, tuple): + # Note: Since we don't have a schema, we just serialize all tuple + # outputs to be a list of values. Even if the output is supposed to + # be a tensor list (Tensor[]), we will serialize it to be a list of + # tensors (Tensor, Tensor, Tensor). An exception is that if there's + # a singleton tensor, we will serialize this to be a singleton + # tensor list so that the deserializer knows to insert getitem nodes. + + idx_to_name = {} + for user in node.users: + if user.target is not operator.getitem: + continue + idx_to_name[user.args[1]] = user.name + + for idx in range(len(meta_val)): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. + if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + if len(meta_val) == 1: + tensors = [] + for i, v in enumerate(meta_val): + assert isinstance(v, torch.Tensor) + tensors.append(self.serialize_tensor_output(idx_to_name[i], v)) + return [Argument.create(as_tensors=tensors)] + + else: + return [ + self.serialize_output(idx_to_name[i], element_meta_val) + for i, element_meta_val in enumerate(meta_val) + ] + + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=()) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create(as_tensor=self.serialize_tensor_output(name, meta_val)) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + return Argument.create(as_sym_int=self.serialize_sym_int_output(name, meta_val)) + elif isinstance(meta_val, torch.SymBool): + # e.g "-> SymBool" + return Argument.create(as_sym_bool=self.serialize_sym_bool_output(name, meta_val)) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert user.target is operator.getitem, f"User node {user} of {node} is incorrect" + idx_to_name[user.args[1]] = user.name + + for idx, _ in enumerate(meta_val): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. + if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + arg_list = [] + for i, element_meta_val in enumerate(meta_val): + arg_list.append( + self.serialize_tensor_output(idx_to_name[i], element_meta_val) + ) + + return arg_list + + def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: + assert isinstance(graph_module, torch.fx.GraphModule) + for node in graph_module.graph.nodes: + try: + getattr(self, f"handle_{node.op}")(node) + except Exception as e: + raise SerializeError(f"Failed serializing node {node} in graph: {node.format_node()}") from e + + return Graph( + inputs=self.graph_state.inputs, + nodes=self.graph_state.nodes, + tensor_values=self.graph_state.tensor_values, + sym_int_values=self.graph_state.sym_int_values, + sym_bool_values=self.graph_state.sym_bool_values, + custom_obj_values=self.graph_state.custom_obj_values, + outputs=self.graph_state.outputs, + is_single_tensor_return=self.graph_state.is_single_tensor_return, + ) + + def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule: + graph = self.serialize_graph(graph_module) + + return GraphModule( + graph=graph, + signature=self.serialize_signature(self.graph_signature), + module_call_graph=self.serialize_module_call_graph(self.module_call_graph), + ) + + +class ExportedProgramSerializer: + def __init__(self, opset_version: Optional[Dict[str, int]] = None): + self.opset_version: Dict[str, int] = {} + if opset_version: + self.opset_version.update(opset_version) + if "aten" not in self.opset_version: + self.opset_version["aten"] = torch._C._get_max_operator_version() + + def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact: + """ + Args: + exported_program: Exported Program to serialize + """ + if type(self) == ExportedProgramSerializer: + exported_program._validate() + + gm_serializer = GraphModuleSerializer( + exported_program.graph_signature, + exported_program.module_call_graph + ) + serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) + serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints) + + # TODO: Directly serialize exported_program.constants once + # CustomClassHolders get stored in the ExportedProgram rather than in + # the graph + constants = {} + for n, c in gm_serializer.custom_objs.items(): + constants[n] = c + for n, t in exported_program.constants.items(): + assert n not in constants + constants[n] = t + + serialized_ep = ExportedProgram( + graph_module=serialized_graph_module, + opset_version=self.opset_version, + range_constraints=serialized_range_constraints, + schema_version=SchemaVersion( + major=SCHEMA_VERSION[0], + minor=SCHEMA_VERSION[1], + ), + dialect=exported_program.dialect, + ) + + # Test canonical form is well defined. + canonicalize(serialized_ep) + + return SerializedArtifact( + serialized_ep, + serialize_torch_artifact(exported_program.state_dict), + serialize_torch_artifact(constants), + ) + + +class GraphModuleDeserializer: + @dataclasses.dataclass + class Result: + graph_module: torch.fx.GraphModule + signature: ep.ExportGraphSignature + module_call_graph: List[ep.ModuleCallEntry] + names_to_symbols: Dict[str, sympy.Symbol] + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] + constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]] + + def __init__(self): + self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} + self.serialized_name_to_meta: Dict[str, MetaType] = {} + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + + @contextmanager + def save_graph_module(self) -> Iterator[None]: + saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + self.serialized_name_to_node = {} + self.serialized_name_to_meta = {} + try: + yield + finally: + self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved + + def deserialize_operator(self, serialized_target: str): + if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this. + module = operator + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("torch"): + module = torch # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + else: # TODO(zhxchen17) Don't catch all here. + return serialized_target + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: + val = s.value + if s.type == "as_expr": + if val.expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[val.expr_str] + else: + sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) + # NOTE(avik): Assumptions on symbols are not explicitly serialized. + # This seems dangerous: it might cause unknown differences in shape env behavior + # on deserialization? Probably deserves a follow-up. + + # Here we force symbols corresponding to SymInts to be at least integers. + # Otherwise some expressions that the shape env would otherwise evaluate to False, + # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + sym = sym.subs({s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}) + if isinstance(sym, sympy.Symbol): + self.symbol_name_to_symbol[val.expr_str] = sym + + if vr := self.symbol_name_to_range.get(val.expr_str): + symbolic_shapes._constrain_symbol_range( + self.shape_env, + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + else: + # Placeholders, in particular, can have shapes as symbolic expressions. + # We need to populate the shape env with the range constraints of their + # free symbols, otherwise evaluating such expressions will error. + self.symbol_name_to_symbol[val.expr_str] = sym + free_symbols = sym.free_symbols + for s in free_symbols: + if s.name not in self.symbol_name_to_symbol: + self.symbol_name_to_symbol[s.name] = s + if vr := self.symbol_name_to_range.get(s.name): + symbolic_shapes._constrain_symbol_range( + self.shape_env, + s, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + + + if val.hint is None: + hint = None + else: + assert val.hint.type == "as_int" + hint = val.hint.value + + return self.shape_env.create_symintnode(sym, hint=hint) + elif s.type == "as_int": + assert isinstance(val, int) + return val + else: + raise SerializeError( + f"SymInt has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: + val = s.value + if s.type == "as_expr": + expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) + return self.shape_env.create_symboolnode(expr) + elif s.type == "as_bool": + assert isinstance(val, bool) + return val + else: + raise SerializeError( + f"SymBool has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_tensor_meta( + self, + tensor_meta: TensorMeta, + ) -> FakeTensor: + with self.fake_tensor_mode: + return cast( + FakeTensor, + torch.empty_strided( + tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc] + tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc] + device=deserialize_device(tensor_meta.device), + dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], + ), + ) + + def deserialize_script_obj_meta(self, script_obj_meta: CustomObjArgument) -> ep.CustomObjArgument: + return ep.CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def deserialize_graph_output(self, output) -> torch.fx.Node: + if output.type == "as_tensor": + return self.serialized_name_to_node[output.as_tensor.name] + elif output.type == "as_sym_int": + return self.serialized_name_to_node[output.as_sym_int.as_name] + elif output.type == "as_sym_bool": + return self.serialized_name_to_node[output.as_sym_bool.as_name] + else: + raise SerializeError(f"Unable to deserialize output node {output}") + + def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: + # Handle the tensor metas. + for name, tensor_value in serialized_graph.tensor_values.items(): + meta_val = self.deserialize_tensor_meta(tensor_value) + self.serialized_name_to_meta[name] = meta_val + + for name, sym_int_value in serialized_graph.sym_int_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value) + + for name, sym_bool_value in serialized_graph.sym_bool_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_sym_bool(sym_bool_value) + + for name, script_obj_meta in serialized_graph.custom_obj_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(script_obj_meta) + + # Inputs: convert to placeholder nodes in FX. + for i, input_ in enumerate(serialized_graph.inputs): + if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"): + node_name = input_.value.name + placeholder_node = self.graph.placeholder(node_name) + self.sync_fx_node(node_name, placeholder_node) + elif input_.type in ("as_int", "as_float", "as_bool", "as_none", "as_string"): + node_name = f"arg{i}" + placeholder_node = self.graph.placeholder(node_name) + placeholder_node.meta["val"] = self.deserialize_input(input_) + else: + raise SerializeError(f"Invalid input type {input_}") + + # Nodes: convert to call_function nodes. + for serialized_node in serialized_graph.nodes: + try: + target = self.deserialize_operator(serialized_node.target) + self.deserialize_node(serialized_node, target) + + except Exception as e: + raise SerializeError(f"Failed deserializing node {serialized_node}") from e + + # Outputs: convert to a single `output` node. + outputs = [] + for output in serialized_graph.outputs: + outputs.append(self.deserialize_graph_output(output)) + + if serialized_graph.is_single_tensor_return: + assert len(outputs) == 1 + outputs = outputs[0] # type: ignore[assignment] + else: + outputs = tuple(outputs) # type: ignore[assignment] + + output_node = self.graph.output(outputs) + + if serialized_graph.is_single_tensor_return: + output_node.meta["val"] = output_node.args[0].meta["val"] + else: + output_node.meta["val"] = tuple( + arg.meta["val"] for arg in output_node.args[0] + ) + + return self.graph + + def deserialize_node(self, serialized_node: Node, target: Callable) -> None: + if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS: + name = serialized_node.outputs[0].value.as_name + args = self.deserialize_sym_op_inputs(serialized_node.inputs) + + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.deserialize_sym_op_outputs(serialized_node, fx_node) + + elif isinstance(target, torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + # If HOP returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + # + # HOPs don't have schema yet, just check the output lengths and as_tensor attribute + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 and hasattr(serialized_node.outputs[0], "as_tensor") + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + + elif isinstance(target, torch._ops.OpOverload): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. + ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node("call_function", target, args, kwargs, name) + self.deserialize_outputs(serialized_node, fx_node) + else: + raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}") + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None + ) + elif i.type == "parameter": + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn), + target=i.custom_obj.custom_obj_name, + ) + else: + raise AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs] + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: bytes, + constants: bytes, + symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} + self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range + self.signature = self.deserialize_signature(serialized_graph_module.signature) + self.constants = deserialize_torch_artifact(constants) + self.deserialize_graph(serialized_graph_module.graph) + + module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph) + return GraphModuleDeserializer.Result( + graph_module=ep._create_graph_module_for_export(self.module, self.graph), + signature=self.signature, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, + state_dict=deserialize_torch_artifact(serialized_state_dict), + constants=self.constants, + ) + finally: + _CURRENT_DESERIALIZER = None + + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): + if name in self.serialized_name_to_node: + raise SerializeError(f"Node {name} has already been deserialized before.") + self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta + fx_node.meta["val"] = self.serialized_name_to_meta[name] + + def deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node): + schema_args = target._schema.arguments + actual_args = { + input.name: self.deserialize_input(input.arg) for input in serialized_node.inputs + } + args = [] + kwargs = {} + for schema_arg in schema_args: + is_positional = not schema_arg.has_default_value() and not schema_arg.kwarg_only + if is_positional: + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. + """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == "as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [] + for arg in value: + result.append(self.serialized_name_to_node[arg.name]) + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ in ("as_sym_ints", "as_sym_bools"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value] + else: + raise SerializeError(f"Unhandled argument {inp}") + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + raise SerializeError(f"Unhandled argument {inp}") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif ( + len(serialized_node.outputs) == 1 and + isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)) + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + def generate_getitem(meta_val, fx_node: torch.fx.Node, arg: Union[TensorArgument, SymIntArgument], idx: int): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + else: + raise AssertionError(f"generate_getitem got unknown argument type {type(arg)}") + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems(meta_val, fx_node: torch.fx.Node, args): + for idx, arg in enumerate(args): + if isinstance(arg, Argument): + arg = arg.value + if isinstance(arg, (TensorArgument, SymIntArgument)): + generate_getitem(meta_val, fx_node, arg, idx) + elif isinstance(arg, (list, tuple)): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + generate_getitems(meta_val[-1], list_output, arg) + list_output.meta.update(deserialized_metadata) + list_output.meta['val'] = meta_val[-1] + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. + # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: List[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors) + else: + generate_getitems(meta_val, fx_node, serialized_node.outputs) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: + ret: Dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + nn_module_stack = dict( + import_nn_module_stack(*item.split(",")) + for item in nn_module_stack_str.split(ST_DELIMITER) + ) + ret["nn_module_stack"] = nn_module_stack + + if source_fn_st_str := metadata.get("source_fn_stack"): + # Originally serializes to "fx_node_name,op_str" + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st + return ret + + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + else: + return ep.ConstantArgument(value=self.deserialize_input(x)) + + def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + ) + + def deserialize_module_call_graph(self, module_call_graph: List[ModuleCallEntry]) -> List[ep.ModuleCallEntry]: + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=self.deserialize_module_call_signature(entry.signature) if entry.signature else None, + ) for entry in module_call_graph + ] + + +class ExportedProgramDeserializer: + def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): + self.expected_opset_version: Dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def deserialize_range_constraints( + self, + symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: Dict[str, sympy.Symbol], + ) -> Dict[sympy.Symbol, ValueRanges]: + range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004 + return range_constraints + + def deserialize( + self, serialized_artifact: SerializedArtifact + ) -> ep.ExportedProgram: + assert isinstance(serialized_artifact.exported_program, ExportedProgram) + + if serialized_artifact.exported_program.schema_version.major != SCHEMA_VERSION[0]: + raise SerializeError( + f"Serialized schema version {serialized_artifact.exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)) + for k, v in serialized_artifact.exported_program.range_constraints.items() + } + res = ( + GraphModuleDeserializer() + .deserialize( + serialized_artifact.exported_program.graph_module, + serialized_artifact.state_dict, + serialized_artifact.constants, + symbol_name_to_range, + ) + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, res.names_to_symbols, + ) + model_opset_version: Optional[Dict[str, int]] = serialized_artifact.exported_program.opset_version + self._validate_model_opset_version(model_opset_version) + + upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version) + + exported_program = ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=None, + verifier=load_verifier(serialized_artifact.exported_program.dialect), + constants=res.constants, + ) + return upgrader.upgrade(exported_program) + + def _validate_model_opset_version(self, model_opset_version: Optional[Dict[str, int]]): + """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version + difference. + E.g., model_opset_version = {"aten": 3, "custom": 4} + expected_opset_version = {"aten": 4, "custom": 4} + This means we can use an upgrader for ATen to reconcile the deserialized model. + + The logic of this method: + + For common op namespaces: + 1. if model version < expected version, this case can be handled by upgraders. + 2. if model version > expected version, we need downgraders but not implemented yet. + 3. if model version == expected version, we don't need extra handling. + + For op namespace only in model_opset_version, we should give a warning because it is missing from + expected_opset_version. + """ + if not model_opset_version: + raise RuntimeError("Serialized model should have opset version.") + common_namespaces = {key for key in model_opset_version if key in self.expected_opset_version} + for namespace in common_namespaces: + assert ( + isinstance(model_version := model_opset_version[namespace], int) + ), f"model_opset_version value should be int, got {model_opset_version[namespace]}" + + assert ( + isinstance(compiler_version := self.expected_opset_version[namespace], int) + ), f"expected_opset_version value should be int, got {self.expected_opset_version[namespace]}" + + # TODO(larryliu0820): Add support for upgrader & downgrader + if model_version != compiler_version: + raise NotImplementedError( + f"Model opset version {model_opset_version} doesn't match to compiler opset version " + f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet." + ) + for namespace in model_opset_version: + if namespace in common_namespaces: + continue + log.warning("Compiler doesn't have a version table for op namespace: {ns}. ", extra={"ns": namespace}) + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): + return base64.b64encode(obj).decode('utf-8') + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + if not (f.default is None and getattr(obj, f.name) is None) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + else: + return obj + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[Dict[str, int]] = None, +) -> SerializedArtifact: + serialized_artifact = ( + ExportedProgramSerializer(opset_version).serialize(exported_program) + ) + assert isinstance(serialized_artifact.exported_program, ExportedProgram) + + + json_program = json.dumps( + _dataclass_to_dict(serialized_artifact.exported_program), cls=EnumEncoder + ) + json_bytes = json_program.encode('utf-8') + artifact = SerializedArtifact( + json_bytes, + serialized_artifact.state_dict, + serialized_artifact.constants + ) + return artifact + + +def _dict_to_dataclass(cls, data): + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." + if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + field_type = cls.__annotations__[_type] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + obj = cls(**data) # type: ignore[assignment] + type_hints = typing.get_type_hints(cls) + for f in dataclasses.fields(cls): + name = f.name + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) + setattr(obj, name, new_field_obj) + return obj + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [ + _dict_to_dataclass(d_type, d) + for d in data + ] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return { + k: _dict_to_dataclass(v_type, v) + for k, v in data.items() + } + return data + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ep.ExportedProgram: + assert isinstance(artifact.exported_program, bytes) + exported_program_str = artifact.exported_program.decode('utf-8') + exported_program_dict = json.loads(exported_program_str) + serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict) + return ( + ExportedProgramDeserializer(expected_opset_version) + .deserialize( + SerializedArtifact( + serialized_exported_program, + artifact.state_dict, + artifact.constants + ) + ) + ) + + +def _canonicalize_graph(sorted_inputs, sorted_outputs, graph) -> Tuple[Graph, Dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return None + elif a.type == "as_operator": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") + + # Stage 1: Reorder named items. + def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: List[int] + ins: int + + graph_inputs: Set[str] = set() + def_table: Dict[str, int] = {} + edges: Dict[int, Edges] = {} + candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = [] + rank: Dict[str, int] = {} + ret: List[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + assert isinstance(a.as_tensor, str) + return a.as_tensor + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input , i) + + for idx, node in enumerate(nodes): + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + def add_edge(a): + if s := get_name(a): + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + if s := get_name(a): + return rank[s] + else: + return -1 + + for i in sorted_inputs: + for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + node = nodes[idx] + args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. + name_table: Dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + assert isinstance(a.as_tensor, str) + a.as_tensor = name_table.get(a.as_tensor, a.as_tensor) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=lambda x: x[0])) + sorted_sym_int_values = dict(sorted(graph.sym_int_values.items(), key=lambda x: x[0])) + sorted_sym_bool_values = dict(sorted(graph.sym_bool_values.items(), key=lambda x: x[0])) + + # Stage 5: Recurse in subgraphs. + counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph = _canonicalize_graph( + a.as_graph.graph.inputs, + a.as_graph.graph.outputs, + a.as_graph.graph + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + ) + return graph, name_table + + +def canonicalize(ep: ExportedProgram) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + + Returns: + ExportedProgram: The canonicalized exported program. + """ + ep = copy.deepcopy(ep) + + opset_version = dict(sorted(ep.opset_version.items(), key=lambda x: x[0])) + range_constraints = dict(sorted(ep.range_constraints.items(), key=lambda x: x[0])) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 3, None, idx + elif spec.type == "loss_output": + return 3, None, idx + elif spec.type == "buffer_mutation": + return 1, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 4, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 5, None, idx + elif spec.type == "user_input_mutation": + return 2, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted(enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input) + sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + + sorted_outs = sorted(enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph(sorted_inputs, sorted_outputs, graph) + + def replace_input(inp): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type in ("as_none", "as_int", "as_float", "as_string", "as_custom_obj"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type in ("as_none", "as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + dialect=ep.dialect, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4530b0b61459f610348910214bc19aaf9110f7f2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..033d000c69d858aa1b8264d90c7d3e984229eb23 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,223 @@ +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions + +debug = os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_ivalues: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. + + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: List[List[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: Dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. + """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). + # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. + kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. + """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. + lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3692b85c8d3ce945b38918345a86e7fb5766287c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1d0f34a7a13c6f7f27968c24e74d0e8cc521e8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..624e7f99a3cd1349fe65619cd008c622240009e6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf8a1552d3823116a5b2873ecdef60e6f0e96664 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..87c325948a8b111d42409140a5d1f8150342794c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py @@ -0,0 +1,142 @@ +import time +from collections import defaultdict +from functools import partial +from typing import DefaultDict + +import torch + + +# Unfortunately it doesn't seem as if there was any way to get TensorBoard to do +# anything without having TF installed, and so this file has a hard dependency on it +# as well. It really is a debugging tool, so it doesn't matter. +try: + from tensorflow.core.util import event_pb2 + from tensorflow.core.framework import graph_pb2 + from tensorflow.python.summary.writer.writer import FileWriter +except ImportError: + raise ImportError("TensorBoard visualization of GraphExecutors requires having " + "TensorFlow installed") from None + + +def dump_tensorboard_summary(graph_executor, logdir): + with FileWriter(logdir) as w: + pb_graph = visualize(graph_executor) + evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString()) + w.add_event(evt) + + +def visualize(graph, name_prefix='', pb_graph=None, executors_it=None): + """Visualizes an independent graph, or a graph executor.""" + value_map = {} + pb_graph = pb_graph or graph_pb2.GraphDef() + + if isinstance(graph, torch._C.GraphExecutorState): + visualize_graph_executor(graph, name_prefix, pb_graph, + partial(visualize, pb_graph=pb_graph)) + return pb_graph + + # Set up an input node + input_node = pb_graph.node.add(op='input', name=name_prefix + 'input') + for i, value in enumerate(graph.param_node().outputs()): + value_map[value.unique()] = name_prefix + 'input:' + str(i) + + visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it) + + # Gather all outputs + return_node = pb_graph.node.add(op='output', name=name_prefix + 'output') + for value in graph.return_node().inputs(): + return_node.input.append(value_map[value.unique()]) + + return pb_graph + + +def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph): + """Append the state of a given GraphExecutor to the graph protobuf. + + Args: + state (GraphExecutor or GraphExecutorState): GraphExecutor to display. + name_prefix (str): Name prefix of the containing subgraph. + pb_graph (GraphDef): graph to append to. + inline_graph (Callable): a function that handles setting up a value_map, + so that some graphs in here can be inlined. This is necessary, because + this will simply be `visualize` for the top-level GraphExecutor, + or `inline_graph` for all nested ones. + + The signature should look like (Graph, name_prefix) -> (). + It will be called exactly once. + + The strategy is to embed all different configurations as independent subgraphs, + while inlining the original graph as the one that actually produces the values. + """ + if state.autograd_fallback_graph is not None: + visualize(graph=state.autograd_fallback_graph, + name_prefix=name_prefix + 'autograd_fallback/', + pb_graph=pb_graph, + executors_it=iter(state.autograd_fallback.executors())) + + for i, (arg_spec, plan) in enumerate(state.execution_plans.items()): + subgraph_name = name_prefix + f'plan{i}/' + + # Create a disconnected node that will keep information regarding the input + # types of this trace. This is unfortunately a bit too verbose to be included + # in the subgraph name. + input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name) + input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii') + + visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors())) + + # Show gradient as an independent subgraph of this plan + if plan.grad_executor is not None: + grad_subgraph_name = subgraph_name + 'grad/' + visualize(plan.grad_executor, grad_subgraph_name, pb_graph) + + return inline_graph(state.graph, name_prefix + 'original/') + + +def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None): + """Recursive part of visualize (basically skips setting up the input and output nodes).""" + def inline_graph(subgraph, name, node): + rec_value_map = {inp.unique(): value_map[val.unique()] + for inp, val in zip(subgraph.inputs(), node.inputs())} + visualize_rec(graph=subgraph, + value_map=rec_value_map, + name_prefix=name, + pb_graph=pb_graph) + for out, val in zip(subgraph.outputs(), node.outputs()): + value_map[val.unique()] = rec_value_map[out.unique()] + + op_id_counter: DefaultDict[str, int] = defaultdict(int) + + def name_for(node): + kind = node.kind()[node.kind().index('::') + 2:] + op_id_counter[kind] += 1 + return kind, name_prefix + kind + '_' + str(op_id_counter[kind]) + + def add_fusion_group(node): + op, name = name_for(node) + inline_graph(node.g('Subgraph'), name + '/', node) + + def add_graph_executor(node): + op, name = name_for(node) + if executors_it is None: + add_node(node) + else: + ge = next(executors_it) + visualize_graph_executor(ge, name + '/', pb_graph, + partial(inline_graph, node=node)) + + def add_node(node): + if node.kind() == 'prim::FusionGroup': + return add_fusion_group(node) + elif node.kind() == 'prim::GraphExecutor': + return add_graph_executor(node) + op, name = name_for(node) + pb_node = pb_graph.node.add(op=op, name=name) + for value in node.inputs(): + pb_node.input.append(value_map[value.unique()]) + # TODO: handle attrs + for i, value in enumerate(node.outputs()): + value_map[value.unique()] = name + ':' + str(i) + + for node in graph.nodes(): + add_node(node) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..02c15ec395d15dbd1012ca3373069e629072bc64 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py @@ -0,0 +1,1202 @@ +import builtins +import copy +import functools +import inspect +import math +import os +import warnings +import collections +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] + +from ._compatibility import compatibility +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: Dict[Type, None] = {} + +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + tracing. + """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + if not is_fx_tracing(): + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_lnotab, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=False) +class PHWithMeta(PHBase): + """ + Object representing an input placeholder to `concrete_args` + """ + def __init__(self, ph_key: Optional[str] = None): + super().__init__() + + # Provide a hey for user to identify placeholder node during analysis + self.ph_key = ph_key + + +def _transfer_attrs(fr, to): + for attr_name in dir(fr): + attr_val = getattr(fr, attr_name) + if ( + not callable(attr_val) + and not attr_name.startswith("__") + and not hasattr(to, attr_name) + ): + setattr(to, attr_name, attr_val) + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. + + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). Backward compatibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: Set[int] = { + id(value) + for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: List[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + # Mapping of node name to module scope + self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, (torch.Tensor, ScriptObject)): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + i = 0 + while True: + qualname = f"_tensor_constant{i}" + if not hasattr(self.root, qualname): + break + i += 1 + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + i = 0 + while True: + qualname = f"_{a.__class__.__name__}_constant_{i}" + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. + """ + return ( + (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) + and not isinstance(m, torch.nn.Sequential) + ) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + key, _ = self.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. + + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. + """ + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: List[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + + # This covers the very specific case where we are passing in flat + # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). + # In this case, just take the concrete_args and pass them through. + name_idx = 0 + if isinstance(concrete_args, tuple) and \ + len(concrete_args) > 0 and \ + (co.co_flags & HAS_VARSTUFF) and \ + total_args == 1: + for concrete_arg in concrete_args: + out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) + if isinstance(concrete_arg, PHBase): + if concrete_arg != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=concrete_arg, to=out.node) + args.append(out) + name_idx += 1 + return root_fn, args + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = dict(zip(arg_names, concrete_args)) + + def proxy_placeholder(name): + return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) + + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if not all(child.is_leaf() for child in in_spec.children_specs): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + + # do real recompilation for _LazyGraphModule before retracing since the trace + # method can not trace the _lazy_forward method. Got error: + # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + # without this. + from torch.fx._lazy_graph_module import _LazyGraphModule + _LazyGraphModule.force_recompile(root) + + self.root = root + + assert hasattr( + type(root), self.traced_func_name + ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + if hasattr(fn, '__code__'): + code = fn.__code__ + self.graph._co_fields = { + 'co_name': code.co_name, + 'co_filename': code.co_filename, + 'co_firstlineno': code.co_firstlineno, + } + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: Dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. + new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k in {'_autowrap_search'}: + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default = ( + () + if param.default is inspect.Parameter.empty + else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if isinstance(x, PHBase): + if x != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=x, to=out.node) + + return out + # Union[int, bool] == bool in Python <= 3.6 + if ( + type(x) == bool + or type(x) in base_types + and type(x) != torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif x is None: + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." + ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default = () + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None) + ) + + +# Dictionary of (id(globals dict), function name) => globals_dict to patch for +# the purposes of the wrap() API. +# We key by the globals dict id and function name to ensure we're wrapping a given +# function only once. +_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {} + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: List[Tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + + def revert(self): + raise NotImplementedError() + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + +class _Patcher: + def __init__(self): + super().__init__() + self.patches_made: List[_PatchedFn] = [] + self.visited: Set[int] = set() + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name]) + ) + frame_dict[name] = new_fn + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) + setattr(cls, name, new_fn) + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + torch.fx.wrap('my_custom_function') + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a*2 + + FX can typically not trace through this due to the presence of control + flow. However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={'b': False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) + assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return _make_graph_module(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9b7eafa88aa351c78387bc6d82fc784444482ed Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6f1dbfadd7b155b57d5e477f799360d10f2293 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ead0037c9e1855d9b5fb17a31acbf54eae1253cb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py @@ -0,0 +1,76 @@ +import os +import sys + +from typing import Optional + +# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations. +translation_validation = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1" +) +# Timeout (in milliseconds) for z3 finding a solution. +# [@compile_ignored: debug] +translation_validation_timeout = int( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000") +) +# Disables bisection for translation validation. +# +# Translation validation bisection is enabled by default, if translation validation +# is also enabled. This should help finding guard simplification issues. However, +# since validation uses Z3 for bisecting, it might take a lot of time. +# +# Set this configuration option so as to avoid bisecting. +# [@compile_ignored: debug] +translation_validation_no_bisect = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1" +) +# Checks whether replaying ShapeEnv events on a freshly constructed one yields +# the a ShapeEnv with the same state. This should be used only in testing. +check_shape_env_recorded_events = False + +# TODO: Perhaps consider allowing unions for the configs below (so you can hit +# multiple reps at the same time) + +# Give extended debug information if the string representation of a guard +# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue +# this guard, we will generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_guard_added = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None +) + +# Give extended debug information when a particular symbol is allocated. For +# example, set this to "u2" and whenever we create this symbol, we will +# generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_create_symbol = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None +) + +# Give extended debug information (C++ backtrace) for all extended debug +# settings as well as errors. The C++ backtrace is slow and very spammy so we +# don't include it by default even when you're requesting extended debug. +# [@compile_ignored: debug] +extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != "" + +# [@compile_ignored: debug] Show a warning for every specialization +print_specializations = False + +# wraps (un)equalities with 'Not' class after recording the correct expression +# in the FX graph. This should incorrectly construct the divisible and replacement +# lists, and incorrectly issue guards. +inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False + +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_version_key = False + +# If we produce more than this many guards on a symbol, force the symbol to +# get specialized and bail out if this many guards mention this particular +# symbol. This may be slightly more aggressive than the true number of guards +# issued (as we test if we've hit the limit on-the-fly, whereas we may +# do further simplifications at final guard issuance time that make guards +# irrelevant.) +symbol_guard_limit_before_specialize: Optional[int] = None + +from torch.utils._config_module import install_config_module + +install_config_module(sys.modules[__name__]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_sym_dispatch_mode.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_sym_dispatch_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6160ea41c941835a0e1d30d0dc4d1ae4b168ef --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_sym_dispatch_mode.py @@ -0,0 +1,58 @@ +from typing import List, Optional, Type + +__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"] + +SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None + + +# SymDispatchMode gets invoked whenever an operation is processed on +# a PySymInt. When this occurs, you get called at __sym_dispatch__ +# with the operation in question. This is symmetric to TorchDispatchMode +# but with some caveats: +# +# - In TorchDispatchMode, you get the same arguments as what a user +# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b), +# you get (a, b) as args to your call. In SymDispatchMode, if +# you call a + b (where a and b are SymInts), you will get +# (a.node, b.node) as your args (these are PySymInts) +# +# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor). +# So you have to manually call Tracer/create_node to write into +# the graph. See ProxySymDispatchMode for an example +# +class SymDispatchMode: + def __sym_dispatch__(self, func, types, args, kwargs): + raise NotImplementedError() + + def __enter__(self): + global SYM_FUNCTION_MODE + old = SYM_FUNCTION_MODE + if hasattr(self, "inner"): + raise RuntimeError( + f"{self} has already been used as a mode. Please use a fresh version" + ) + else: + self.inner = old + SYM_FUNCTION_MODE = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global SYM_FUNCTION_MODE + SYM_FUNCTION_MODE = self.inner + + +def handle_sym_dispatch(func, args, kwargs): + global SYM_FUNCTION_MODE + mode = sym_function_mode() + assert mode + SYM_FUNCTION_MODE = mode.inner + try: + # TODO: properly compute types + types: List[Type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) + finally: + SYM_FUNCTION_MODE = mode + + +def sym_function_mode(): + return SYM_FUNCTION_MODE diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..548d1d3852b022d5c589dae53aa3556517ab112b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py @@ -0,0 +1,289 @@ +import re +from typing import Callable, Dict, Optional, Set, Union + +import torch.fx +from torch.fx.node import map_arg +from torch.fx.passes.split_module import split_module + + +__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs'] + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + fx_const_folded_attrs_name: Optional[str] = None, + device_for_folded_attrs: str = "cuda", + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.has_folding_been_run = False + self.fx_const_folded_attrs_name = fx_const_folded_attrs_name + self.device_for_folded_attrs = device_for_folded_attrs + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if ( + self.const_subgraph_module is None + or self.fx_const_folded_attrs_name is None + ): + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. Note that single attr const fold + # subgraphs output a single Tensor while multiple outputs are returned as + # Tuple[Tensor,]. + folded_attrs = self.const_subgraph_module() + + def _create_param(i): + return torch.nn.Parameter( + i + if not isinstance(i, int) + else torch.Tensor([i]).to(device=self.device_for_folded_attrs), + requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, + ) + + params = ( + torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) + if isinstance(folded_attrs, tuple) + else _create_param(folded_attrs) + ) + setattr(self, self.fx_const_folded_attrs_name, params) + + +def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): + """ + Given `gm` and some graph module which is called with target name `inline_mod_name`, + this helper will inline all of the nodes from that called graph module into `gm`. + """ + # Fetch the inner graph module that we want to inline inside `gm`. + inline_mod = dict(gm.named_modules())[inline_mod_name] + assert isinstance(inline_mod, torch.fx.GraphModule) + call_mod_node_to_replace = None + for node in gm.graph.nodes: + if node.op == "call_module" and node.target == inline_mod_name: + call_mod_node_to_replace = node + break + assert call_mod_node_to_replace is not None + + # Now actually do the swap. Note that we have to keep track of new nodes that are + # copied into `gm` -- we do this via replacement_mapping. + call_mod_args = call_mod_node_to_replace.args + replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {} + ph_count = 0 + + def replacement_fn(node): + new_node = replacement_mapping[node] + new_node.meta = node.meta.copy() + return new_node + + for inline_node in inline_mod.graph.nodes: + if inline_node.op == "placeholder": + replacement_mapping[inline_node] = call_mod_args[ph_count] + ph_count += 1 + continue + + if inline_node.op == "output": + outputs = inline_node.args[0] + output_replacements = map_arg(outputs, replacement_fn) + call_mod_node_to_replace.replace_all_uses_with(output_replacements) + continue + + with gm.graph.inserting_before(call_mod_node_to_replace): + new_node = gm.graph.node_copy(inline_node, replacement_fn) + replacement_mapping[inline_node] = new_node + + gm.graph.eliminate_dead_code() + + +def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: + """ + Make sure the name is unique (in a module) and can represents an attr. + """ + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + + return name + + +def split_const_subgraphs( + module: Union[torch.nn.Module, torch.fx.GraphModule], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + device_for_folded_attrs: str = "cpu", +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + if not isinstance(module, torch.fx.GraphModule): + mod_traced = torch.fx.symbolic_trace(module) + else: + mod_traced = module + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: Set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op != "get_attr" and not set(node.all_input_nodes).issubset( + const_nodes + ): + continue + + # If provided skip folding function says to skip, then skip. + if skip_folding_node_fn and skip_folding_node_fn(node): + continue + + # Skip folding side-effectful functions + if node.is_impure(): + continue + + # Must be a constant foldable node at this point. + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + const_gm, non_const_gm = split.submod_0, split.submod_1 + const_mod_name, non_const_mod_name = "submod_0", "submod_1" + + # The module that a call_module node refers to gets copied to submodules during split. + # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to + # attach inlined modules to `split` as it's the owning module now. + for node in non_const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside const_gm, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to const_gm must be constants accessible via get_attr. + call_const_gm_args = None + for node in split.graph.nodes: + if node.op == "call_module": + if node.target == const_mod_name: + call_const_gm_args = node.args + break + assert call_const_gm_args is not None + + # Here we do the actual replacement of placeholders to get_attrs. Note that here we + # set the const_gm.graph into a new root_const_gm with split as the root module, + # because we are fetching attributes directly from the root module, instead of + # fetching them from const_gm. Example: The const_gm must have some format like: + # graph(): + # %inp : [num_users=1] = placeholder[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) + # return add + # We replace that with the following, which does not have any placeholders: + # graph(): + # %inp_1 : [num_users=1] = get_attr[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) + # return add + root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + for node in root_const_gm.graph.nodes: + if node.op == "output": + multiple_outputs = isinstance(node.args[0], tuple) + continue + if node.op != "placeholder": + continue + in_node = next(n for n in call_const_gm_args if n.name == node.target) + assert in_node.op == "get_attr" + with root_const_gm.graph.inserting_before(node): + new_node = root_const_gm.graph.get_attr(in_node.target) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + root_const_gm.graph.erase_node(node) + assert "multiple_outputs" in locals() + + # Now find the call to const_gm inside split, and replace it with a getattr to the + # folded tensor(s) that result from constant folding. Note that we don't need to + # worry about whether this is one or more tensors because the original graph + # correctly uses getitem to extract individual tensors if there are multiple folded. + fx_const_folded_attrs_name = get_unique_attr_name_in_module( + split, "_FX_CONST_FOLDED_ATTRS" + ) + setattr( + split, + fx_const_folded_attrs_name, + torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined] + ) + for node in split.graph.nodes: + if node.op == "call_module" and node.target == const_mod_name: + with node.graph.inserting_before(node): + folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) + folded_attrs.meta = node.meta.copy() + node.replace_all_uses_with(folded_attrs) + break + + split.graph.eliminate_dead_code() + + # Finally, inline the non-constant submod into the split submod. This is so that the + # original caller who may have passed in a graph module will get back out a graph + # module whose graph is traced to the same granularity. + _inline_module(split, non_const_mod_name) + + return FoldedGraphModule( + split, + split.graph, + root_const_gm.graph, + fx_const_folded_attrs_name, + device_for_folded_attrs, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py new file mode 100644 index 0000000000000000000000000000000000000000..e44a75ddad085a5c00d01b65e4a182d5025bd683 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py @@ -0,0 +1,914 @@ +from functools import reduce +import torch +import operator +from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise +from typing import Callable, Dict +from torch.fx.node import Target, Node +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d +from torch.fx.experimental.refinement_types import Equality +import itertools + +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +import sympy + +_INFERENCE_RULES: Dict[Target, Callable] = {} +_REFINEMENT_RULES: Dict[Target, Callable] = {} +_RULES: Dict[Target, Callable] = {} + + +def expand_to_tensor_dim(t, n): + """ + Expand a type to the desired tensor dimension if possible + Raise an error otherwise. + - t is the given type + - n is a number of dimensions to expand to + """ + if t == Dyn: + dims = [Dyn] * n + return TensorType(tuple(dims)) + elif isinstance(t, TensorType): + if len(t.__args__) != n: + raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') + return t + else: + raise TypeError(f'Cannot match the type {t}') + + +def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with eachother and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return t1, t2 + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + s1 = len(t1.__args__) + s2 = len(t2.__args__) + + new_t1 = list(t1.__args__) + new_t2 = list(t2.__args__) + + # We make the types the same length which is the first requirement + # for consistency + if s1 > s2: + for i in range(s1 - s2): + new_t2.insert(0, 1) + + elif s2 > s1: + for i in range(s2 - s1): + new_t1.insert(0, 1) + + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor + for i, (x, y) in enumerate(zip(new_t1, new_t2)): + if x == 1: + new_t1[i] = y + elif y == 1: + new_t2[i] = x + + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation + (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) + return (t1, t2) + else: + raise TypeError(f'Cannot broadcast types {t1} and {t2}') + +def register_inference_rule(call_target): + def register(fn): + if call_target in _INFERENCE_RULES: + raise RuntimeError(f'Inference rule already registered for {call_target}!') + _INFERENCE_RULES[call_target] = fn + return fn + return register + +def register_refinement_rule(call_target): + def register(fn): + if call_target in _REFINEMENT_RULES: + raise RuntimeError(f'Refinement rule already registered for {call_target}!') + _REFINEMENT_RULES[call_target] = fn + return fn + return register + +def register_algebraic_expressions_inference_rule(call_target): + def register(fn): + if call_target in _RULES: + raise RuntimeError(f'Rule already registered for {call_target}!') + _RULES[call_target] = fn + return fn + return register + +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + t1 = n.args[0].type + t2 = n.args[1].type + + # handle scalar addition + if t1 == int and isinstance(t2, TensorType): + n.type = t2 + return n.type + + # handle scalar addition + elif t2 == int and isinstance(t1, TensorType): + n.type = t1 + return n.type + + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point + (new_t1, new_t2) = broadcast_types(t1, t2) + + if new_t1 != t1 or new_t2 != t2: + n.meta['broadcast'] = True + n.meta[str(n.args[0])] = new_t1 + n.meta[str(n.args[1])] = new_t2 + + else: + n.meta['broadcast'] = False + + new_t1 = t1 if not n.meta['broadcast'] else new_t1 + new_t2 = t2 if not n.meta['broadcast'] else new_t2 + + # we check for consistency between the new types + if is_consistent(new_t1, new_t2): + # we return the less precise type because + # broadcasting may have happened + # for operands with shape [1,2,Dyn] and [1,2,1] + # we have to assign the node [1,2,Dyn] + if is_more_precise(new_t1, new_t2): + n.type = new_t2 + else: + n.type = new_t1 + return n.type + else: + raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' + f' Types should match ') + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ + attr_node = n.args[0] + attr_name = n.args[1] + + if attr_name == "shape": + n.type = Dyn + else: + raise TypeError("Not yet implemented") + + # TODO. We leave it like this till we add a type to represent tensor sizes + return n.type + +@register_inference_rule(torch.transpose) +def transpose_inference_rule(n: Node): + """ + We check that dimensions for the transpose operations + are within range of the tensor type of the node + """ + if n.target == torch.transpose: + assert isinstance(n.args[0], Node) + t = n.args[0].type + + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + dim1, dim2 = n.args[1], n.args[2] + + if t == Dyn: + n.type = Dyn + return n.type + + elif isinstance(t, TensorType): + if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): + new_type = list(t.__args__) + new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] + final = TensorType(new_type) + n.type = get_greatest_upper_bound(n.type, final) + return n.type + else: + raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + else: + raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ + assert isinstance(n.args[0], Node) + t1 = n.args[0].type + + assert isinstance(n.args[1], list) + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) + + # if we do not know the original tensor dimension, + # we return the required dimension + if t1 == Dyn: + n.type = t2_type + return t2_type + + # if any of the dimensions are unknown, + # we check for divisibility + elif isinstance(t1, TensorType): + assert isinstance(t1, TensorType) + a = [e if e != Dyn else 1 for e in t1.__args__] + p1 = reduce(operator.mul, a) + p2 = reduce(operator.mul, t2) + if p1 % p2 == 0 or p2 % p1 == 0: + n.type = t2_type + return t2_type + else: + raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + else: + raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + +@register_inference_rule(BatchNorm2d) +def bn2d_inference_rule(n: Node, module_instance): + """ + Given a BatchNorm2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - t is consistent with t' + - x_2 is consistent with the module's num_features + - x_2' is consistent with the module's num_features + output type: the more precise type of t and t' + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + n.type = expand_to_tensor_dim(n.type, 4) + + # we check the conditions on the incoming argument + # and any existing annotation + # we also check for consistency between both annotations + if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ + is_consistent(n.type.__args__[1], module_instance.num_features) and \ + is_consistent(arg_type, n.type): + + # we choose the more precise type + # to be the node type + # so if an incoming argument has more type information + # we set this node's type to be the argument type + n.type = get_greatest_upper_bound(arg_type, n.type) + return n.type + else: + raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') + + +def calculate_out_dimension(d_in, module_instance, index): + """ + For calculating h_in and w_out according to the conv2D documentation + """ + padding = (module_instance.padding, module_instance.padding) \ + if isinstance(module_instance.padding, int) else module_instance.padding + kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ + if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size + stride = (module_instance.stride, module_instance.stride) \ + if isinstance(module_instance.stride, int) else module_instance.stride + dilation = (module_instance.dilation, module_instance.dilation) \ + if isinstance(module_instance.dilation, int) else module_instance.dilation + + DIMENSION_TYPES = (int, sympy.Symbol) + + if d_in == Dyn: + return Dyn + + elif isinstance(d_in, DIMENSION_TYPES): + n = d_in + 2 * padding[index] - \ + dilation[index] * \ + (kernel_size[index] - 1) - 1 + + return (n // stride[0]) + 1 + + else: + raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') + + +def get_greatest_upper_bound(type1, type2): + """ + Get the most precise type that's consistent with the given types + """ + if type1 == Dyn: + return type2 + elif type2 == Dyn: + return type1 + elif isinstance(type1, TensorType) and isinstance(type2, TensorType): + if not is_consistent(type1, type2): + raise TypeError(f'Inconsistent types {type1}, {type2}') + gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] + return TensorType(tuple(gub)) + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance): + """ + Given a Conv2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - x_2 is consistent with the module's in_channels + - let o = (x_1, out_channels, H_out, W_out) + then the output is the greatest upper bound of o and the existing node type t'. + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + curr_node_type = expand_to_tensor_dim(n.type, 4) + + if is_consistent(arg_type.__args__[1], module_instance.in_channels): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) + gub = get_greatest_upper_bound(new_type, curr_node_type) + n.type = gub + return n.type + else: + raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') + + +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + n.type = get_greatest_upper_bound(n.args[0].type, n.type) + return n.type + + +def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ + new_type_list = list(typ.__args__) + if len(new_type_list) == 4 or len(new_type_list) == 3: + w_in = new_type_list[-1] + h_in = new_type_list[-2] + + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + + new_type_list[-1] = w_out + new_type_list[-2] = h_out + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f'Wrong size {typ} for {module_instance}') + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool2d_inference_rule(n: Node, module_instance): + """ + Given a MaxPool2D instance and a node check the following conditions: + - Input size matches size 3 or 4 + - Current node type is consistent with the output type we will calculate + - Input size matches output size and the last two dimensions of the output + are w_out and h_out. The remaining dimensions are the same as the input + - Our final result is the greatest upper bound of the output we calculate + and the current node type. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output = maxpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output, n.type) + return n.type + + + +def linear_check(tensor_type, module_instance): + """ + Checks that an input tensor type satisfies the conditions for linear operation + and returns the output type based on in and out features given by module_instance + """ + if len(tensor_type.__args__) >= 2: + if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): + new_type_args = list(tensor_type.__args__) + new_type_args[-1] = module_instance.out_features + return TensorType(tuple(new_type_args)) + else: + raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') + else: + raise TypeError(f'Type {tensor_type} must have rank 2 or more.') + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = linear_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output_type, n.type) + return n.type + + +def adaptiveavgpool2d_check(tensor_type, module_instance): + output_size = module_instance.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + elif isinstance(output_size, tuple): + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = output_size[1] + if output_size[1] is None: + output_size[1] = output_size[0] + + new_type_list = list(tensor_type.__args__) + + if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: + new_type_list[-1] = output_size[1] + new_type_list[-2] = output_size[0] + + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptiveavgpool2d_inference_rule(n: Node, module_instance): + """ + The input and output sizes should be the same except for the last + two dimensions taken from the input, which represent width and height + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(n.type, output_type) + return n.type + +def flatten_check(tensor_type, start_dim, end_dim): + l = len(tensor_type.__args__) + + start_dim = l if start_dim == -1 else abs(start_dim) + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: + my_args = list(tensor_type.__args__) + lhs = my_args[0:start_dim] + rhs = my_args[end_dim:] + mid = my_args[start_dim:end_dim] + if Dyn in mid: + mid = [Dyn] + else: + mid = [reduce(operator.mul, my_args[start_dim:end_dim])] + new_type_list = lhs + mid + rhs + return TensorType(tuple(new_type_list)) + else: + raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}') + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + output_type = flatten_check(n.args[0].type, start_dim, end_dim) + n.type = get_greatest_upper_bound(output_type , n.type) + + return n.type + +class GraphTypeChecker: + def __init__(self, env, traced): + self.env = env + self.traced = traced + + def type_check(self): + """ + A gradual type checker for graphs + Effect: every node's field type will be + populated with a type after type-checking is done + """ + graph = self.traced.graph + + # type check every node with gradual type rules + # if any node does not type check return false + for n in graph.nodes: + self.type_check_node(n) + return True + + def type_check_node(self, n: Node): + """ + Type check a given fx node. + Current operations: + - Reshape + - Transpose + - Add + - Relu + - conv2d + - batchnorm2d + - flatten + - maxpool2d + - adaptiveavgpool2d + - linear + """ + if n.type is None: + n.type = Dyn + + if n.op == 'placeholder': + return n.type + + elif n.op == 'get_attr': + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] + if isinstance(t.data, torch.Tensor): + n.type = TensorType(t.data.shape) + return n.type + + elif n.op == 'call_function': + if n.target == getattr: + assert getattr in _INFERENCE_RULES + return _INFERENCE_RULES[n.target](n, self.traced) + + elif n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, module_instance) + else: + raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + + elif n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") + + +@register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(torch.nn.Linear) +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + +@register_refinement_rule(BatchNorm2d) +@register_refinement_rule(torch.nn.ReLU) +def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[i], args2[i]) for i in range(len(args1))] + return res + + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + +@register_refinement_rule(torch.add) +@register_refinement_rule(operator.add) +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ + res = [] + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + arg_type1 = n.args[0].type + arg_type2 = n.args[1].type + if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): + args1, args2 = broadcast_types(arg_type1, arg_type2) + # by this point, we know that args1 and args2 are the same size. + a1 = args1.__args__ + a2 = args2.__args__ + a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions + r = [] + for x, y, z in zip(a1, a2, a3): + if x == y: + r.append(Equality(x, z)) + res = r + return res + + +@register_refinement_rule(torch.flatten) +def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ + assert isinstance(n.args[0], Node) + + eq_const = [] + + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): + l = len(n.type.__args__) + arg_type = n.args[0].type + start_dim = l if start_dim == -1 else start_dim + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): + eq_const.append(Equality(t1, t2)) + + for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): + eq_const.append(Equality(t1, t2)) + return eq_const + + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the outout in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + +class Refine: + """ + Symbolic shape inference. + Generates constraints over type variables. + Currently all constraints are equality constraints. + """ + def __init__(self, traced): + self.constraints = [] + self.traced = traced + self.symbol_iter = itertools.count(start=0, step=1) + + def refine(self): + """ + Generates constraints for + every node in the graph based on + the operation. + """ + graph = self.traced.graph + for n in graph.nodes: + self.refine_node(n) + return True + + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + + def replace_dyn_with_fresh_var(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if typ == Dyn: + new_symbol = Var(next(self.symbol_iter)) + return new_symbol + elif isinstance(typ, TensorType): + new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ + + def refine_node(self, n: Node): + """ + Returns a list of equality constraints for + call_module and call_function nodes. + Models the relation between input and output dimensions + using constraints in case they are both tensors. + All operations used in resnet50 are defined. + """ + if n.type is None: + n.type = Dyn + + n.type = self.replace_dyn_with_fresh_var(n.type) + + if n.op == 'call_function': + if n.target in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[n.target](n) + else: + pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[type(module_instance)](n) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + def infer_symbolic_relations(self, n: Node): + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == 'call_function': + if n.target in _RULES: + return _RULES[n.target](n) + else: + pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..06bc2309975caf6197bbe6ff0c3c4cffeff7ee51 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py @@ -0,0 +1,162 @@ +import operator +from typing import Any, Callable, Dict, Tuple, Optional + +import torch +import torch.fx +import torch.fx as fx +from torch.fx import Transformer, Proxy +from torch.fx.node import Argument, Target, Node, map_aggregate +from torch.fx.operator_schemas import ( + normalize_module, + normalize_function, + create_type_hint, +) + +from .schema_type_annotation import AnnotateTypesWithSchema + + +class NormalizeArgs(Transformer): + """ + Normalize arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and rewritten to exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. Also populates default + values. Does not support positional-only parameters or varargs + parameters (*args, **kwargs). + + If the nodes have 'type' metadata, it will use it to disambiguate + overloads. Otherwise, it will throw an error. + + Example usage: + m = torchvision.models.resnet18() + traced = torch.fx.symbolic_trace(m) + traced = NormalizeArgs(traced).transform() + """ + + def __init__( + self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True + ): + super().__init__(module) + self.node_map: Dict[Proxy, Node] = {} + self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs + + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + def get_type(arg): + if isinstance(arg, fx.Node): + return n.meta["type"] if "type" in n.meta else None + return type(arg) + + arg_types = map_aggregate(n.args, get_type) + assert isinstance(arg_types, tuple) + arg_types = tuple([create_type_hint(i) for i in arg_types]) + kwarg_types = {k: get_type(v) for k, v in kwargs.items()} + if n.op == "call_function": + out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) + else: + out = super().run_node(n) + if n.op != "output": + self.node_map[out] = n + out.node.meta = n.meta + out.node.type = n.type + return out + + def call_function( + self, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + arg_types: Optional[Tuple[Any, ...]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + ): + assert callable(target) + new_args_and_kwargs = normalize_function( + target, + args, # type: ignore[arg-type] + kwargs, + arg_types, # type: ignore[arg-type] + kwarg_types, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return self.tracer.create_proxy( + "call_function", target, new_args, new_kwargs + ) + else: + return super().call_function(target, args, kwargs) + + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): + assert isinstance(target, str) + new_args_and_kwargs = normalize_module( + self.module, + target, + args, # type: ignore[arg-type] + kwargs, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return super().call_module(target, new_args, new_kwargs) + else: + return super().call_module(target, args, kwargs) + + +class NormalizeOperators(AnnotateTypesWithSchema): + """ + Normalize callsites that are different ways of "spelling" the same + invocation into a single, canonical call. Currently supports: + + 1. Normalize operators (e.g. operator.add) to the `torch` ops they + ultimately invoke (e.g. torch.add) when it is possible to statically + reason that + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = NormalizeOperators(traced).transform() + """ + + binary_magic_method_remap: Dict[ + Callable[[Any, Any], Any], Callable[[Any, Any], Any] + ] = { + torch.add: operator.add, + torch.mul: operator.mul, + torch.sub: operator.sub, + torch.div: operator.truediv, + torch.floor_divide: operator.floordiv, + torch.remainder: operator.mod, + torch.eq: operator.eq, + torch.ne: operator.ne, + torch.lt: operator.lt, + torch.le: operator.le, + torch.gt: operator.gt, + torch.ge: operator.ge, + } + + def call_function( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): + # Normalize operators according to the magic methods implemented on tensors here: + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + + assert callable(target) + + if target in self.binary_magic_method_remap: + if len(args) != 2: + return super().call_function(target, args, kwargs) + lhs, rhs = args + + return super().call_function( + target=self.binary_magic_method_remap[target], + args=(lhs, rhs), + kwargs={}, + ) + + return super().call_function(target, args, kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d96c6b40667f334870a07ad4aa09d207f95080f4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,317 @@ +from enum import Enum +from typing import NamedTuple, Dict, List, Set + +from torch.fx.node import Node, map_arg + + +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + + def __init__(self, partition_id: int) -> None: + self.nodes: Set[Node] = set() + self.partition_id = partition_id + self.parents: Set[Partition] = set() + self.children: Set[Partition] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: List[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {"placeholder", "get_attr"}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all( + n not in self.nodes for n in input_node.users + ) and input_node.op in {"placeholder", "get_attr"}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + + +class PartitionerConfig(NamedTuple): + devices: List[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0.0 + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + node_to_partition_mapping: Dict[Node, int] = {} + partition_to_logical_device_mapping: Dict[int, List[int]] = {} + # Saturate host by replicating partitions to the remaining idle devices. + saturate_host: bool = False + + +def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError("node has no size_bytes attr") + # Don't forget the op node itself + size_bytes = getattr(node, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError("node has no size_bytes attr") + return total_size_of_input_nodes + + +def get_latency_of_one_partition( + partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> List[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: List[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {"placeholder", "get_attr"}: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any( + n in partition.nodes and n.op not in {"placeholder", "get_attr"} + for n in input_nodes + ): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + max( + node_latency.computer_latency_sec, node_latency.mem_latency_sec + ) + # Update the mem latency of this path + mem_latency_sec = ( + partition_latency.mem_latency_sec + node_latency.mem_latency_sec + ) + # Update the compute latency of this path + computer_latency_sec = ( + partition_latency.computer_latency_sec + node_latency.computer_latency_sec + ) + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper( + n, + PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ), + ) + if ( + new_partition_latency.overall_latency_sec + > max_latency.overall_latency_sec + ): + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ) + + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper( + node, + PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ), + ) + if ( + partition_latency.overall_latency_sec + > critical_path_latency.overall_latency_sec + ): + critical_path_latency = partition_latency + return critical_path_latency + + +def get_partition_to_latency_mapping( + partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] +) -> Dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition( + partition, node_to_latency_mapping + ) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + + +def get_comm_latency_between( + parent_partition: Partition, + child_partition: Partition, + transfer_rate_bytes_per_sec: float, +): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if ( + parent_partition.logical_device_ids != [] + and child_partition.logical_device_ids != [] + and parent_partition.logical_device_ids == child_partition.logical_device_ids + ): + return 0.0 + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + + +def get_latency_of_partitioned_graph( + partitions: List[Partition], + partition_to_latency_mapping: Dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float, +): + """Given all partitions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions""" + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec + children = partition.children + if partition.children: + max_latency_sec = 0.0 + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between( + partition, child, transfer_rate_bytes_per_sec + ) + new_latency_sec = dfs_helper( + child, latency_so_far_sec + comm_latency_sec + ) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: List[Partition]) -> List[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + top_partitions = [] + for partition in partitions: + # If a partition has no parents, then it is a top partition + if len(partition.parents) == 0: + top_partitions.append(partition) + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0.0 + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.0) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a840408618a1cf4b1be4a2be136935f964ba2a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py @@ -0,0 +1,111 @@ +import torch +import torch.fx +import inspect +from typing import Any, Dict, Optional, Tuple +from torch.fx.node import Argument, Target +from torch._jit_internal import boolean_dispatched +from torch.fx.operator_schemas import _torchscript_type_to_python_type + +from torch.fx import Transformer + +class AnnotateTypesWithSchema(Transformer): + """ + Use Python function signatures to annotate types for `Nodes` within an FX graph. + This pulls out Python function signatures for: + + 1. Standard `torch.nn` Module calls + 2. `torch.nn.functional` calls + 3. Attribute fetches via `get_attr` + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = AnnotateTypesWithSchema(traced).transform() + + """ + def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, + annotate_modules : bool = True, annotate_get_attrs : bool = True): + super().__init__(module) + self.annotate_functionals = annotate_functionals + self.annotate_modules = annotate_modules + self.annotate_get_attrs = annotate_get_attrs + + def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + python_ret_type = None + if self.annotate_functionals and target.__module__ == 'torch.nn.functional': + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched['if_true'], dispatched['if_false'] + # TODO: can we emit the union of these? What are the implications on TorchScript + # compilation? + if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: + return super().call_function(target, args, kwargs) + target_for_analysis = if_true + + python_ret_type = self._extract_python_return_type(target_for_analysis) + + return_proxy = super().call_function(target, args, kwargs) + return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return return_proxy + + def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + python_ret_type = None + assert isinstance(target, str) + submod = self.fetch_attr(target) + if self.annotate_modules and hasattr(submod.__class__, '__name__'): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + python_ret_type = self._extract_python_return_type(submod.forward) + return_proxy = super().call_module(target, args, kwargs) + return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return return_proxy + + def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + attr_proxy = super().get_attr(target, args, kwargs) + + if self.annotate_get_attrs: + module_itr = self.module + assert isinstance(target, str) + atoms = target.split('.') + for i, atom in enumerate(atoms): + if not hasattr(module_itr, atom): + raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') + module_itr = getattr(module_itr, atom) + + maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) + if maybe_inferred_ts_type.success(): + python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) + attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type + + return attr_proxy + + def _extract_python_return_type(self, target : Target) -> Optional[Any]: + """ + Given a Python call target, try to extract the Python return annotation + if it is available, otherwise return None + + Args: + + target (Callable): Python callable to get return annotation for + + Returns: + + Optional[Any]: Return annotation from the `target`, or None if it was + not available. + """ + assert callable(target) + try: + sig = inspect.signature(target) + except (ValueError, TypeError): + return None + + return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eb0a6abd86f2d2036032aec894298862a322cf --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py @@ -0,0 +1,594 @@ +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + +import torch +import torch.utils.hooks +from torch._namedtensor_internals import check_serializing_named_tensor + +try: + # Early load resource_sharer to prevent a partially initialized instance + # from being inherited in a forked child process. The reduce_storage method + # requires this module indirectly through DupFd(). The built-in mp.Queue + # class pickles arguments in a background thread which may overlap with the + # fork. + import multiprocessing.resource_sharer +except ImportError: + pass + + +class StorageWeakRef: + r"""A weak reference to a Storage. + + The cdata member is a Python number containing the integer representation of + the Storage pointer. + """ + + __slots__ = ["cdata", "_free_weak_ref"] + + def __init__(self, storage): + self.cdata = storage._weak_ref() + # Save a direct reference to _free_weak_ref because the `torch` module + # might be cleared during Python shutdown before this module is cleared. + self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + + @classmethod + def from_weakref(cls, cdata): + instance = cls.__new__(cls) + instance.cdata = cdata + instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + return instance + + def expired(self): + return torch.Storage._expired(self.cdata) # type: ignore[attr-defined] + + def __del__(self): + self._free_weak_ref(self.cdata) + + def __hash__(self): + return self.cdata + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.cdata == other.cdata + + +class SharedCache(dict): + """Dictionary from multiprocessing handles to StorageWeakRef.""" + + def __init__(self): + # free_dead_references() is called if the len exceeds the current + # limit. The limit scales with the number of remaining live objects. + self.limit = 128 + # `fork` inherits lock state, so in case we fork when the lock is held, + # we register a function to reset the lock to a new object to avoid + # possible deadlocks, following python multiprocessing library design. + self._after_fork() + register_after_fork(self, SharedCache._after_fork) + + def _after_fork(self): + self.lock = threading.Lock() + + def get(self, key): + with self.lock: + return dict.get(self, key) + + def __setitem__(self, key, storage_ref): + with self.lock: + dict.__setitem__(self, key, storage_ref) + if len(self) > self.limit: + self.free_dead_references() + + def free_dead_references(self): + live = 0 + for key, storage_ref in list(self.items()): + if storage_ref.expired(): + del self[key] + else: + live += 1 + self.limit = max(128, live * 2) + + +# mapping from handles to StorageWeakRef objects +shared_cache = SharedCache() + + +def rebuild_event(device, handle): + return torch.cuda.Event.from_ipc_handle(device, handle) + + +def reduce_event(event): + handle = event.ipc_handle() + return (rebuild_event, (event.device, handle)) + + +def rebuild_tensor(cls, storage, metadata): + storage_offset, size, stride, requires_grad = metadata + t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + if cls == torch.nn.parameter.Parameter: + # we have to pass requires_grad into constructor, rather than set it as an + # attribute later, because it's an important check for Integer Tensors to + # have requires_grad=False (or else they raise an error) + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + return t + + +def rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # If storage_handle is None, storage points to nullptr. + if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) + else: + storage = storage_from_cache( + storage_cls, (storage_handle, storage_offset_bytes) + ) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( + storage + ) + else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. + storage_cls._release_ipc_counter( + ref_counter_handle, ref_counter_offset, device=storage_device + ) + + _storage = ( + storage + if isinstance(storage, torch.UntypedStorage) + else storage._untyped_storage + ) + + t = torch._utils._rebuild_tensor( + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def reduce_tensor(tensor): + if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) + + check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + # Note [CUDA IPC and the caching allocator] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # When you send a CUDA tensor over IPC, you might expect that you will + # get out the same storage from the other end. However, the CUDA caching + # allocator makes it difficult to preserve this invariant. Consider + # the following situation: a tensor of size 0x100 points to offset 0x20 of + # a storage at 0xA100 of size 0x100. (For simplicity, all of these + # sizes are given in bytes). HOWEVER, with the caching allocator, this storage + # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000. + # + # When we want to send this CUDA tensor over IPC, we must send the + # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just + # the storage 0xA100 (because that is what CUDA supports). So, on the + # other end, there simply isn't any way to say, "Wait, you gave me + # a bigger region (0xA000) than the one I wanted (0xA100)". + # + # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as + # one storage itself? No, because this cudaMalloc allocation might contain + # storages of mixed types: float, bytes, double... If you make the entire + # allocation a single storage of a type A, we'll hit an error when constructing + # a tensor of type B on the storage. + # + # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the + # receiver side. However, cudaIpcMemHandles from each device in a given process may + # only be opened by one context per device per other process. + # If we open and close a memory handle multiples times in a process, CUDA is allowed + # to give it a different address; similarly, once we close the memory, we're not + # allowed to access it(and the storage/tensor built on top of it), even if it is + # still live in the original process. As we cannot make a cudaMalloc allocation + # to a single storage in one go, this requires us to cache the device pointer for + # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep + # the old ones alives. + # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html] + # + # This is fine, because all we need to do is to save our position in the allocation, + # and reconstruct storage and tensor from it. + # 0xA000 -> -------CUDA Allocation------ + # | | + # | | + # | | + # | | + # 0xA100 -> --------storage1 begin------ + # | | + # 0xA120 -> --------tensor1 begin ------ + # | | + # | | + # | | + # | | + # | | + # 0xA160 -> --------tensor1 end--------- + # | | + # | | + # | | + # 0xA200 -> --------storage1 end-------- + # | | + # 0xE000 -> --------CUDA allocation----- + # + # To send tensor1, the following info are required from sender to receiver for + # storage recontruction. + # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). + # basePtr may not be exactly 0xA000 since it's a different process. + # 2. offset(0xA100) of storage1 in the CUDA allocation. + # 3. size of storage1(0x100). + # + # On receiver side: + # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage + # of the same type using (basePtr, offset, size). + # 2. we can reconstruct the tensor on top of the reconstructed storage + # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100)) + # + # This strategy has a few implications: + # + # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one + # go (non-compositionally), and this requires to have a global map + # memHandle -> devPtr for each process. + # + # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize + # of the storage beyond 0x100 would merely have caused us to do a + # reallocation. You don't really want to do this, but if you did, + # all that would happen is that you would lose IPC sharing. But if + # you do this in the new world, we will happily let you write out of + # bounds of your "allocation", clobbering unrelated data in the cached + # allocator block. BAD! + # + # By the way, in old versions of PyTorch, we supported this situation + # natively using a "storage view", which permitted multiple storages to be + # views on each other. But this was the *only* use of storage views, so we + # eliminated it so that we could just use tensor views to implement the same + # thing. + # + + # TODO: Handle distinguishing between subclass and non-subclass versions of NT better + # https://github.com/pytorch/pytorch/issues/110543 + from torch.nested._internal.nested_tensor import NestedTensor + + if tensor.is_nested and not isinstance(tensor, NestedTensor): + return reduce_nested_tensor(tensor) + + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + return ( + rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. + storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) + + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) + + +def rebuild_nested_tensor( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, +): + buffer = rebuild_buffer_func(*rebuild_buffer_args) + sizes = rebuild_sizes_func(*rebuild_sizes_args) + strides = rebuild_strides_func(*rebuild_strides_args) + offsets = rebuild_offsets_func(*rebuild_offsets_args) + return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets) + + +def reduce_nested_tensor(nt): + rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values()) + rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size()) + rebuild_strides_func, rebuild_strides_args = reduce_tensor( + nt._nested_tensor_strides() + ) + rebuild_offsets_func, rebuild_offsets_args = reduce_tensor( + nt._nested_tensor_storage_offsets() + ) + + return ( + rebuild_nested_tensor, + ( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, + ), + ) + + +def rebuild_sparse_coo_tensor( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + is_coalesced, +): + indices = rebuild_indices_func(*rebuild_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) + + +def rebuild_sparse_compressed_tensor( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + layout, +): + compressed_indices = rebuild_compressed_indices_func( + *rebuild_compressed_indices_args + ) + plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, shape, layout=layout + ) + + +def reduce_sparse_tensor(sparse): + if sparse.layout is torch.sparse_coo: + rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) + return ( + rebuild_sparse_coo_tensor, + ( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.is_coalesced(), + ), + ) + else: + if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = sparse.crow_indices() + plain_indices = sparse.col_indices() + elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: + compressed_indices = sparse.ccol_indices() + plain_indices = sparse.row_indices() + else: + raise NotImplementedError(sparse.layout) + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + ) = reduce_tensor(compressed_indices) + rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( + plain_indices + ) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) + return ( + rebuild_sparse_compressed_tensor, + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.layout, + ), + ) + + +def fd_id(fd): + # Returns a tuple which uniquely identifies a file descriptor. In Mac OS, + # this doesn't work with shared memory handles, which is why we don't + # support the "file_descriptor" sharing method on that platform. + stat = os.fstat(fd) + return (stat.st_ino, stat.st_dev) + + +def storage_from_cache(cls, key): + storage_ref = shared_cache.get(key) + if storage_ref is None: + return None + return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) + + +def rebuild_storage_fd(cls, df, size): + fd = df.detach() + try: + storage = storage_from_cache(cls, fd_id(fd)) + if storage is not None: + return storage + storage = cls._new_shared_fd_cpu(fd, size) + shared_cache[fd_id(fd)] = StorageWeakRef(storage) + return storage + finally: + os.close(fd) + + +def rebuild_storage_filename(cls, manager, handle, size, dtype=None): + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( + cls, handle + ) + if storage is not None: + return storage._shared_decref() + if dtype is None: + storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) + else: + byte_size = size * torch._utils._element_size(dtype) + untyped_storage: torch.UntypedStorage = ( + torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + ) + storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + shared_cache[handle] = StorageWeakRef(storage) + return storage._shared_decref() + + +def rebuild_storage_empty(cls): + return cls() + + +def rebuild_typed_storage(storage, dtype): + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) + + +# Use for torch.storage.TypedStorage +def reduce_typed_storage(storage): + return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) + + +def rebuild_typed_storage_child(storage, storage_type): + return storage_type(wrap_storage=storage, _internal=True) + + +# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage +def reduce_typed_storage_child(storage): + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) + + +def reduce_storage(storage): + from . import get_sharing_strategy + + if storage.is_cuda: + raise RuntimeError( + "Cannot pickle CUDA storage; try pickling a CUDA tensor instead" + ) + elif get_sharing_strategy() == "file_system": + metadata = storage._share_filename_cpu_() + cache_key = metadata[1] + rebuild = rebuild_storage_filename + if isinstance(storage, torch.TypedStorage): + metadata += (storage.dtype,) + storage._shared_incref() + elif storage.size() == 0: + # This is special cased because Empty tensors + # (with size 0) cannot be mmapped. + return (rebuild_storage_empty, (type(storage),)) + else: + fd, size = storage._share_fd_cpu_() + df = multiprocessing.reduction.DupFd(fd) + cache_key = fd_id(fd) + metadata = (df, size) + rebuild = rebuild_storage_fd # type: ignore[assignment] + + shared_cache[cache_key] = StorageWeakRef(storage) + return (rebuild, (type(storage),) + metadata) + + +def init_reductions(): + ForkingPickler.register(torch.cuda.Event, reduce_event) + + for t in torch._storage_classes: + if t.__name__ == "UntypedStorage": + ForkingPickler.register(t, reduce_storage) + else: + ForkingPickler.register(t, reduce_typed_storage_child) + + ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + + for t in torch._tensor_classes: + ForkingPickler.register(t, reduce_tensor) + + # TODO: Maybe this should be in tensor_classes? :) + ForkingPickler.register(torch.Tensor, reduce_tensor) + ForkingPickler.register(torch.nn.parameter.Parameter, reduce_tensor)